325 lines
12 KiB
Rust
325 lines
12 KiB
Rust
use axum::{Json, extract::Extension, http::StatusCode};
|
|
use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64_STANDARD};
|
|
use serde_json::{Value, json};
|
|
use shared_contracts::creation_agent_document_input::{
|
|
CreationAgentDocumentInputPayload, ParseCreationAgentDocumentInputRequest,
|
|
ParseCreationAgentDocumentInputResponse,
|
|
};
|
|
|
|
use crate::{
|
|
api_response::json_success_body, http_error::AppError, request_context::RequestContext,
|
|
};
|
|
|
|
/// Maximum accepted size of the decoded document, in bytes (256 KiB).
const MAX_DOCUMENT_INPUT_BYTES: usize = 256 * 1024;
|
|
/// Maximum accepted length of the trimmed base64 string, checked before decoding.
///
/// 256 KiB of data encodes to 349,528 padded base64 characters, so this limit
/// (368,640) leaves some headroom; the decoded length is still checked against
/// `MAX_DOCUMENT_INPUT_BYTES` afterwards.
const MAX_DOCUMENT_INPUT_BASE64_CHARS: usize = 360 * 1024;
|
|
/// Lower-case file extensions (without the dot) accepted for uploaded documents.
const SUPPORTED_DOCUMENT_EXTENSIONS: &[&str] = &["txt", "md", "markdown", "csv", "json"];
|
|
|
|
pub async fn parse_creation_agent_document_input(
|
|
Extension(request_context): Extension<RequestContext>,
|
|
Json(payload): Json<ParseCreationAgentDocumentInputRequest>,
|
|
) -> Result<Json<Value>, AppError> {
|
|
let file_name = normalize_file_name(&payload.file_name)?;
|
|
ensure_supported_extension(&file_name)?;
|
|
let content_base64 = payload.content_base64.trim();
|
|
if content_base64.len() > MAX_DOCUMENT_INPUT_BASE64_CHARS {
|
|
return Err(
|
|
AppError::from_status(StatusCode::BAD_REQUEST).with_details(json!({
|
|
"message": "文档过大,请上传 256KB 以内的文本文件。",
|
|
"field": "contentBase64",
|
|
"maxSizeBytes": MAX_DOCUMENT_INPUT_BYTES,
|
|
})),
|
|
);
|
|
}
|
|
|
|
let decoded = BASE64_STANDARD.decode(content_base64).map_err(|_| {
|
|
AppError::from_status(StatusCode::BAD_REQUEST).with_details(json!({
|
|
"message": "文档内容编码无效,请重新选择文件。",
|
|
"field": "contentBase64",
|
|
}))
|
|
})?;
|
|
|
|
if decoded.is_empty() {
|
|
return Err(
|
|
AppError::from_status(StatusCode::BAD_REQUEST).with_details(json!({
|
|
"message": "文档内容为空,请选择有内容的文件。",
|
|
"field": "contentBase64",
|
|
})),
|
|
);
|
|
}
|
|
|
|
if decoded.len() > MAX_DOCUMENT_INPUT_BYTES {
|
|
return Err(
|
|
AppError::from_status(StatusCode::BAD_REQUEST).with_details(json!({
|
|
"message": "文档过大,请上传 256KB 以内的文本文件。",
|
|
"field": "contentBase64",
|
|
"maxSizeBytes": MAX_DOCUMENT_INPUT_BYTES,
|
|
"actualSizeBytes": decoded.len(),
|
|
})),
|
|
);
|
|
}
|
|
|
|
let text = String::from_utf8(decoded.clone()).map_err(|_| {
|
|
AppError::from_status(StatusCode::BAD_REQUEST).with_details(json!({
|
|
"message": "暂时只支持 UTF-8 文本文档,请转换编码后再上传。",
|
|
"field": "contentBase64",
|
|
}))
|
|
})?;
|
|
let normalized_text = normalize_document_text(&text);
|
|
|
|
if normalized_text.trim().is_empty() {
|
|
return Err(
|
|
AppError::from_status(StatusCode::BAD_REQUEST).with_details(json!({
|
|
"message": "文档解析后没有可用文本,请换一个文件。",
|
|
"field": "contentBase64",
|
|
})),
|
|
);
|
|
}
|
|
|
|
Ok(json_success_body(
|
|
Some(&request_context),
|
|
ParseCreationAgentDocumentInputResponse {
|
|
document: CreationAgentDocumentInputPayload {
|
|
file_name,
|
|
content_type: payload
|
|
.content_type
|
|
.as_deref()
|
|
.map(str::trim)
|
|
.filter(|value| !value.is_empty())
|
|
.map(str::to_string),
|
|
size_bytes: decoded.len(),
|
|
text: normalized_text,
|
|
},
|
|
},
|
|
))
|
|
}
|
|
|
|
fn normalize_file_name(value: &str) -> Result<String, AppError> {
|
|
let normalized = value
|
|
.trim()
|
|
.rsplit(['/', '\\'])
|
|
.next()
|
|
.unwrap_or_default()
|
|
.trim()
|
|
.to_string();
|
|
|
|
if normalized.is_empty() {
|
|
return Err(
|
|
AppError::from_status(StatusCode::BAD_REQUEST).with_details(json!({
|
|
"message": "缺少文档文件名。",
|
|
"field": "fileName",
|
|
})),
|
|
);
|
|
}
|
|
|
|
Ok(normalized)
|
|
}
|
|
|
|
fn ensure_supported_extension(file_name: &str) -> Result<(), AppError> {
|
|
let extension = file_name
|
|
.rsplit_once('.')
|
|
.map(|(_, extension)| extension.trim().to_ascii_lowercase())
|
|
.filter(|extension| !extension.is_empty())
|
|
.ok_or_else(|| unsupported_document_error(file_name))?;
|
|
|
|
if !SUPPORTED_DOCUMENT_EXTENSIONS.contains(&extension.as_str()) {
|
|
return Err(unsupported_document_error(file_name));
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
fn unsupported_document_error(file_name: &str) -> AppError {
|
|
AppError::from_status(StatusCode::BAD_REQUEST).with_details(json!({
|
|
"message": "暂时只支持 txt、md、csv、json 文本文档。",
|
|
"field": "fileName",
|
|
"fileName": file_name,
|
|
"supportedExtensions": SUPPORTED_DOCUMENT_EXTENSIONS,
|
|
}))
|
|
}
|
|
|
|
/// Normalizes decoded document text.
///
/// Strips any leading U+FEFF byte-order marks, converts `\r\n` and lone `\r`
/// line endings to `\n`, and trims surrounding whitespace.
fn normalize_document_text(value: &str) -> String {
    let without_bom = value.trim_start_matches('\u{feff}');

    let mut unified = String::with_capacity(without_bom.len());
    let mut chars = without_bom.chars().peekable();
    while let Some(current) = chars.next() {
        if current == '\r' {
            // Collapse a CRLF pair into a single newline; a lone CR becomes one too.
            if chars.peek() == Some(&'\n') {
                chars.next();
            }
            unified.push('\n');
        } else {
            unified.push(current);
        }
    }

    unified.trim().to_string()
}
|
|
|
|
#[cfg(test)]
mod tests {
    use axum::{
        body::Body,
        http::{Request, StatusCode},
    };
    use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64_STANDARD};
    use http_body_util::BodyExt;
    use module_auth::{PhoneAuthScene, PhoneLoginInput, SendPhoneCodeInput};
    use platform_auth::{
        AccessTokenClaims, AccessTokenClaimsInput, AuthProvider, BindingStatus, sign_access_token,
    };
    use serde_json::{Value, json};
    use std::path::PathBuf;
    use time::OffsetDateTime;
    use tower::ServiceExt;

    use super::MAX_DOCUMENT_INPUT_BASE64_CHARS;
    use crate::{app::build_router, config::AppConfig, state::AppState};

    /// Route exercised by every test in this module.
    const PARSE_ENDPOINT: &str = "/api/runtime/creation-agent/document-inputs/parse";

    /// Builds an authenticated JSON POST to the parse endpoint with the given body.
    fn parse_request(access_token: &str, body: Value) -> Request<Body> {
        Request::builder()
            .method("POST")
            .uri(PARSE_ENDPOINT)
            .header("authorization", format!("Bearer {access_token}"))
            .header("content-type", "application/json")
            .header("x-genarrative-response-envelope", "1")
            .body(Body::from(body.to_string()))
            .expect("request should build")
    }

    #[tokio::test]
    async fn parse_document_input_returns_text_payload() {
        let state = build_test_state("ok").await;
        let access_token = seed_authenticated_token(&state, "13800138110").await;

        let response = build_router(state)
            .oneshot(parse_request(
                &access_token,
                json!({
                    "fileName": "世界设定.md",
                    "contentType": "text/markdown",
                    "contentBase64": BASE64_STANDARD.encode("第一章\r\n潮湿的港口")
                }),
            ))
            .await
            .expect("request should succeed");

        assert_eq!(response.status(), StatusCode::OK);

        let body = response
            .into_body()
            .collect()
            .await
            .expect("body should collect")
            .to_bytes();
        let payload: Value =
            serde_json::from_slice(&body).expect("response body should be valid json");

        assert_eq!(payload["ok"], Value::Bool(true));
        assert_eq!(
            payload["data"]["document"]["fileName"],
            json!("世界设定.md")
        );
        // The CRLF in the upload must come back as a single LF.
        assert_eq!(
            payload["data"]["document"]["text"],
            json!("第一章\n潮湿的港口")
        );
    }

    #[tokio::test]
    async fn parse_document_input_rejects_unsupported_extension() {
        let state = build_test_state("bad-ext").await;
        let access_token = seed_authenticated_token(&state, "13800138111").await;

        let response = build_router(state)
            .oneshot(parse_request(
                &access_token,
                json!({
                    "fileName": "世界设定.docx",
                    "contentBase64": BASE64_STANDARD.encode("binary")
                }),
            ))
            .await
            .expect("request should succeed");

        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
    }

    #[tokio::test]
    async fn parse_document_input_rejects_large_base64_before_decode() {
        let state = build_test_state("large-base64").await;
        let access_token = seed_authenticated_token(&state, "13800138112").await;

        let response = build_router(state)
            .oneshot(parse_request(
                &access_token,
                json!({
                    "fileName": "世界设定.txt",
                    // One character over the pre-decode length limit.
                    "contentBase64": "A".repeat(MAX_DOCUMENT_INPUT_BASE64_CHARS + 1)
                }),
            ))
            .await
            .expect("request should succeed");

        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
    }

    /// Creates a user through the phone-code login flow and returns a signed
    /// access token for that user.
    async fn seed_authenticated_token(state: &AppState, phone_number: &str) -> String {
        let now = OffsetDateTime::now_utc();
        state
            .phone_auth_service()
            .send_code(
                SendPhoneCodeInput {
                    phone_number: phone_number.to_string(),
                    scene: PhoneAuthScene::Login,
                },
                now,
            )
            .await
            .expect("phone code should send");
        let user = state
            .phone_auth_service()
            .login(
                PhoneLoginInput {
                    phone_number: phone_number.to_string(),
                    verify_code: "123456".to_string(),
                },
                now + time::Duration::seconds(1),
            )
            .await
            .expect("phone login should create user")
            .user;

        let claims = AccessTokenClaims::from_input(
            AccessTokenClaimsInput {
                user_id: user.id,
                session_id: "sess_creation_doc_input".to_string(),
                provider: AuthProvider::Password,
                roles: vec!["user".to_string()],
                token_version: user.token_version,
                phone_verified: true,
                binding_status: BindingStatus::Active,
                display_name: Some(user.display_name),
            },
            state.auth_jwt_config(),
            OffsetDateTime::now_utc(),
        )
        .expect("claims should build");

        sign_access_token(&claims, state.auth_jwt_config()).expect("token should sign")
    }

    /// Builds an `AppState` backed by a per-test auth store file so tests do not
    /// share persisted users; any file left by a previous run is removed first.
    async fn build_test_state(label: &str) -> AppState {
        let mut config = AppConfig::default();
        config.auth_store_path = PathBuf::from(format!(
            ".codex-temp/api-server-auth-store-creation-doc-{label}.json"
        ));
        // The file may legitimately not exist yet, so a failure here is ignored.
        let _ = std::fs::remove_file(&config.auth_store_path);

        AppState::new(config).expect("state should build")
    }
}
|