1
This commit is contained in:
@@ -90,6 +90,9 @@ pub struct AppConfig {
|
||||
pub dashscope_reference_image_model: String,
|
||||
pub dashscope_cover_image_model: String,
|
||||
pub dashscope_image_request_timeout_ms: u64,
|
||||
pub apimart_base_url: String,
|
||||
pub apimart_api_key: Option<String>,
|
||||
pub apimart_image_request_timeout_ms: u64,
|
||||
pub draft_asset_generation_max_concurrent_requests: usize,
|
||||
pub ark_character_video_base_url: String,
|
||||
pub ark_character_video_api_key: Option<String>,
|
||||
@@ -182,6 +185,9 @@ impl Default for AppConfig {
|
||||
dashscope_reference_image_model: "qwen-image-2.0".to_string(),
|
||||
dashscope_cover_image_model: "wan2.2-t2i-flash".to_string(),
|
||||
dashscope_image_request_timeout_ms: 150_000,
|
||||
apimart_base_url: "https://api.apimart.ai/v1".to_string(),
|
||||
apimart_api_key: None,
|
||||
apimart_image_request_timeout_ms: 180_000,
|
||||
draft_asset_generation_max_concurrent_requests: 4,
|
||||
ark_character_video_base_url: DEFAULT_ARK_BASE_URL.to_string(),
|
||||
ark_character_video_api_key: None,
|
||||
@@ -530,6 +536,18 @@ impl AppConfig {
|
||||
config.dashscope_image_request_timeout_ms = dashscope_image_request_timeout_ms;
|
||||
}
|
||||
|
||||
if let Some(apimart_base_url) = read_first_non_empty_env(&["APIMART_BASE_URL"]) {
|
||||
config.apimart_base_url = apimart_base_url;
|
||||
}
|
||||
|
||||
config.apimart_api_key = read_first_non_empty_env(&["APIMART_API_KEY"]);
|
||||
|
||||
if let Some(apimart_image_request_timeout_ms) =
|
||||
read_first_positive_u64_env(&["APIMART_IMAGE_REQUEST_TIMEOUT_MS"])
|
||||
{
|
||||
config.apimart_image_request_timeout_ms = apimart_image_request_timeout_ms;
|
||||
}
|
||||
|
||||
if let Some(max_concurrent_requests) = read_first_usize_env(&[
|
||||
"GENARRATIVE_DRAFT_ASSET_GENERATION_MAX_CONCURRENT_REQUESTS",
|
||||
"DRAFT_ASSET_GENERATION_MAX_CONCURRENT_REQUESTS",
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use platform_llm::{LlmClient, LlmMessage, LlmStreamDelta, LlmTextRequest};
|
||||
use platform_llm::{LlmClient, LlmError, LlmMessage, LlmStreamDelta, LlmTextRequest};
|
||||
use serde_json::Value as JsonValue;
|
||||
|
||||
use crate::llm_model_routing::CREATION_TEMPLATE_LLM_MODEL;
|
||||
@@ -33,10 +33,63 @@ where
|
||||
{
|
||||
let llm_client =
|
||||
llm_client.ok_or_else(|| build_error(messages.model_unavailable.to_string()))?;
|
||||
let user_prompt = user_prompt.into();
|
||||
let turn_output = match request_stream_creation_agent_json_turn(
|
||||
llm_client,
|
||||
system_prompt.clone(),
|
||||
user_prompt.clone(),
|
||||
enable_web_search,
|
||||
&mut on_reply_update,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(turn_output) => Ok(turn_output),
|
||||
Err(CreationAgentJsonTurnFailure::Stream(error))
|
||||
if enable_web_search && is_web_search_tool_unavailable(&error) =>
|
||||
{
|
||||
tracing::warn!(
|
||||
error = %error,
|
||||
"创作 Agent 联网搜索插件不可用,自动降级为无联网搜索重试"
|
||||
);
|
||||
request_stream_creation_agent_json_turn(
|
||||
llm_client,
|
||||
system_prompt,
|
||||
user_prompt,
|
||||
false,
|
||||
&mut on_reply_update,
|
||||
)
|
||||
.await
|
||||
}
|
||||
Err(error) => Err(error),
|
||||
};
|
||||
|
||||
turn_output.map_err(|error| match error {
|
||||
CreationAgentJsonTurnFailure::Stream(_) => {
|
||||
build_error(messages.generation_failed.to_string())
|
||||
}
|
||||
CreationAgentJsonTurnFailure::Parse => build_error(messages.parse_failed.to_string()),
|
||||
})
|
||||
}
|
||||
|
||||
enum CreationAgentJsonTurnFailure {
|
||||
Stream(LlmError),
|
||||
Parse,
|
||||
}
|
||||
|
||||
async fn request_stream_creation_agent_json_turn<F>(
|
||||
llm_client: &LlmClient,
|
||||
system_prompt: String,
|
||||
user_prompt: String,
|
||||
enable_web_search: bool,
|
||||
on_reply_update: &mut F,
|
||||
) -> Result<CreationAgentJsonTurnOutput, CreationAgentJsonTurnFailure>
|
||||
where
|
||||
F: FnMut(&str),
|
||||
{
|
||||
let mut latest_reply_text = String::new();
|
||||
let response = llm_client
|
||||
.stream_text(
|
||||
build_creation_agent_llm_request(system_prompt, user_prompt.into(), enable_web_search),
|
||||
build_creation_agent_llm_request(system_prompt, user_prompt, enable_web_search),
|
||||
|delta: &LlmStreamDelta| {
|
||||
if let Some(reply_progress) =
|
||||
extract_reply_text_from_partial_json(delta.accumulated_text.as_str())
|
||||
@@ -48,9 +101,9 @@ where
|
||||
},
|
||||
)
|
||||
.await
|
||||
.map_err(|_| build_error(messages.generation_failed.to_string()))?;
|
||||
.map_err(CreationAgentJsonTurnFailure::Stream)?;
|
||||
let parsed = parse_json_response_text(response.content.as_str())
|
||||
.map_err(|_| build_error(messages.parse_failed.to_string()))?;
|
||||
.map_err(|_| CreationAgentJsonTurnFailure::Parse)?;
|
||||
let reply_text = read_reply_text(&parsed);
|
||||
if let Some(reply_text) = reply_text.as_deref()
|
||||
&& reply_text != latest_reply_text
|
||||
@@ -61,6 +114,13 @@ where
|
||||
Ok(CreationAgentJsonTurnOutput { parsed })
|
||||
}
|
||||
|
||||
fn is_web_search_tool_unavailable(error: &LlmError) -> bool {
|
||||
let message = error.to_string();
|
||||
message.contains("ToolNotOpen")
|
||||
|| message.contains("has not activated web search")
|
||||
|| message.contains("未开通")
|
||||
}
|
||||
|
||||
fn build_creation_agent_llm_request(
|
||||
system_prompt: String,
|
||||
user_prompt: String,
|
||||
@@ -168,11 +228,23 @@ fn read_reply_text(parsed: &JsonValue) -> Option<String> {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::{
|
||||
fs,
|
||||
io::{Read, Write},
|
||||
net::TcpListener,
|
||||
sync::{Arc, Mutex},
|
||||
thread,
|
||||
time::{Duration as StdDuration, SystemTime, UNIX_EPOCH},
|
||||
};
|
||||
|
||||
use platform_llm::{LlmConfig, LlmProvider};
|
||||
|
||||
use crate::llm_model_routing::CREATION_TEMPLATE_LLM_MODEL;
|
||||
|
||||
use super::{
|
||||
build_creation_agent_llm_request, extract_reply_text_from_partial_json,
|
||||
parse_json_response_text,
|
||||
CreationAgentLlmTurnErrorMessages, build_creation_agent_llm_request,
|
||||
extract_reply_text_from_partial_json, is_web_search_tool_unavailable,
|
||||
parse_json_response_text, stream_creation_agent_json_turn,
|
||||
};
|
||||
|
||||
#[test]
|
||||
@@ -202,4 +274,214 @@ mod tests {
|
||||
assert_eq!(request.protocol, platform_llm::LlmTextProtocol::Responses);
|
||||
assert_eq!(request.messages.len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn detects_upstream_web_search_tool_unavailable_error() {
|
||||
let error = platform_llm::LlmError::Upstream {
|
||||
status_code: 502,
|
||||
message: "Your account has not activated web search. code=ToolNotOpen".to_string(),
|
||||
};
|
||||
|
||||
assert!(is_web_search_tool_unavailable(&error));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn stream_turn_retries_without_web_search_when_tool_is_unavailable() {
|
||||
let log_dir = std::env::temp_dir().join(format!(
|
||||
"api-server-creation-agent-raw-log-test-{}-{}",
|
||||
std::process::id(),
|
||||
SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.expect("system time should be after epoch")
|
||||
.as_nanos()
|
||||
));
|
||||
unsafe {
|
||||
std::env::set_var("LLM_RAW_LOG_DIR", &log_dir);
|
||||
}
|
||||
let success_json = serde_json::json!({
|
||||
"replyText": "好,我们先把玩具王国定住。",
|
||||
"progressPercent": 12,
|
||||
"nextAnchorContent": {
|
||||
"worldPromise": "玩具王国初步方向",
|
||||
"playerFantasy": null,
|
||||
"themeBoundary": null,
|
||||
"playerEntryPoint": null,
|
||||
"coreConflict": null,
|
||||
"keyRelationships": null,
|
||||
"hiddenLines": null,
|
||||
"iconicElements": null
|
||||
}
|
||||
})
|
||||
.to_string();
|
||||
let server = spawn_capturing_mock_server(vec![
|
||||
MockResponse {
|
||||
body: concat!(
|
||||
"data: {\"type\":\"error\",\"code\":\"ToolNotOpen\",\"message\":\"Your account has not activated web search.\"}\n\n",
|
||||
"data: [DONE]\n\n"
|
||||
)
|
||||
.to_string(),
|
||||
},
|
||||
MockResponse {
|
||||
body: format!(
|
||||
"data: {}\n\n",
|
||||
serde_json::json!({
|
||||
"type": "response.output_text.delta",
|
||||
"delta": success_json
|
||||
})
|
||||
) + "data: {\"type\":\"response.completed\"}\n\n",
|
||||
},
|
||||
]);
|
||||
let config = LlmConfig::new(
|
||||
LlmProvider::Ark,
|
||||
server.base_url,
|
||||
"test-key".to_string(),
|
||||
"test-model".to_string(),
|
||||
30_000,
|
||||
0,
|
||||
1,
|
||||
)
|
||||
.expect("LLM config should build");
|
||||
let llm_client = platform_llm::LlmClient::new(config).expect("LLM client should build");
|
||||
let mut visible_replies = Vec::new();
|
||||
|
||||
let output = stream_creation_agent_json_turn(
|
||||
Some(&llm_client),
|
||||
"系统提示".to_string(),
|
||||
"用户提示",
|
||||
true,
|
||||
CreationAgentLlmTurnErrorMessages {
|
||||
model_unavailable: "模型不可用",
|
||||
generation_failed: "生成失败",
|
||||
parse_failed: "解析失败",
|
||||
},
|
||||
|text| visible_replies.push(text.to_string()),
|
||||
|message| message,
|
||||
)
|
||||
.await
|
||||
.expect("web search fallback should succeed");
|
||||
|
||||
assert_eq!(
|
||||
output.parsed["replyText"].as_str(),
|
||||
Some("好,我们先把玩具王国定住。")
|
||||
);
|
||||
assert_eq!(visible_replies, vec!["好,我们先把玩具王国定住。"]);
|
||||
|
||||
let requests = server.requests.lock().expect("requests lock").clone();
|
||||
assert_eq!(requests.len(), 2);
|
||||
assert!(requests[0].contains("\"tools\""));
|
||||
assert!(requests[0].contains("\"web_search\""));
|
||||
assert!(!requests[1].contains("\"tools\""));
|
||||
|
||||
unsafe {
|
||||
std::env::remove_var("LLM_RAW_LOG_DIR");
|
||||
}
|
||||
if log_dir.exists() {
|
||||
fs::remove_dir_all(log_dir).expect("temporary LLM raw log dir should be removed");
|
||||
}
|
||||
}
|
||||
|
||||
struct MockResponse {
|
||||
body: String,
|
||||
}
|
||||
|
||||
struct CapturingMockServer {
|
||||
base_url: String,
|
||||
requests: Arc<Mutex<Vec<String>>>,
|
||||
}
|
||||
|
||||
fn spawn_capturing_mock_server(responses: Vec<MockResponse>) -> CapturingMockServer {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").expect("listener should bind");
|
||||
let address = listener.local_addr().expect("listener should have addr");
|
||||
let requests = Arc::new(Mutex::new(Vec::new()));
|
||||
let requests_for_thread = Arc::clone(&requests);
|
||||
|
||||
thread::spawn(move || {
|
||||
for response in responses {
|
||||
let (mut stream, _) = listener.accept().expect("request should connect");
|
||||
let request_text = read_request(&mut stream);
|
||||
requests_for_thread
|
||||
.lock()
|
||||
.expect("requests lock")
|
||||
.push(request_text);
|
||||
write_sse_response(&mut stream, response);
|
||||
}
|
||||
});
|
||||
|
||||
CapturingMockServer {
|
||||
base_url: format!("http://{address}"),
|
||||
requests,
|
||||
}
|
||||
}
|
||||
|
||||
fn read_request(stream: &mut std::net::TcpStream) -> String {
|
||||
stream
|
||||
.set_read_timeout(Some(StdDuration::from_secs(1)))
|
||||
.expect("read timeout should be set");
|
||||
let mut buffer = Vec::new();
|
||||
let mut chunk = [0_u8; 1024];
|
||||
let mut expected_total = None;
|
||||
|
||||
loop {
|
||||
match stream.read(&mut chunk) {
|
||||
Ok(0) => break,
|
||||
Ok(bytes_read) => {
|
||||
buffer.extend_from_slice(&chunk[..bytes_read]);
|
||||
|
||||
if expected_total.is_none()
|
||||
&& let Some(header_end) = find_header_end(&buffer)
|
||||
{
|
||||
let content_length =
|
||||
read_content_length(&buffer[..header_end]).unwrap_or(0);
|
||||
expected_total = Some(header_end + content_length);
|
||||
}
|
||||
|
||||
if let Some(total_bytes) = expected_total
|
||||
&& buffer.len() >= total_bytes
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(error)
|
||||
if error.kind() == std::io::ErrorKind::WouldBlock
|
||||
|| error.kind() == std::io::ErrorKind::TimedOut =>
|
||||
{
|
||||
break;
|
||||
}
|
||||
Err(error) => panic!("mock server failed to read request: {error}"),
|
||||
}
|
||||
}
|
||||
|
||||
String::from_utf8_lossy(buffer.as_slice()).to_string()
|
||||
}
|
||||
|
||||
fn write_sse_response(stream: &mut std::net::TcpStream, response: MockResponse) {
|
||||
let raw_response = format!(
|
||||
"HTTP/1.1 200 OK\r\nContent-Type: text/event-stream; charset=utf-8\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
|
||||
response.body.len(),
|
||||
response.body
|
||||
);
|
||||
|
||||
stream
|
||||
.write_all(raw_response.as_bytes())
|
||||
.expect("mock response should be written");
|
||||
stream.flush().expect("mock response should flush");
|
||||
}
|
||||
|
||||
fn find_header_end(buffer: &[u8]) -> Option<usize> {
|
||||
buffer
|
||||
.windows(4)
|
||||
.position(|window| window == b"\r\n\r\n")
|
||||
.map(|index| index + 4)
|
||||
}
|
||||
|
||||
fn read_content_length(headers: &[u8]) -> Option<usize> {
|
||||
let text = String::from_utf8_lossy(headers);
|
||||
text.lines().find_map(|line| {
|
||||
let (name, value) = line.split_once(':')?;
|
||||
if name.eq_ignore_ascii_case("content-length") {
|
||||
return value.trim().parse::<usize>().ok();
|
||||
}
|
||||
None
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -98,8 +98,13 @@ const PUZZLE_WORKS_PROVIDER: &str = "puzzle-works";
|
||||
const PUZZLE_GALLERY_PROVIDER: &str = "puzzle-gallery";
|
||||
const PUZZLE_RUNTIME_PROVIDER: &str = "puzzle-runtime";
|
||||
const PUZZLE_TEXT_TO_IMAGE_MODEL: &str = "wan2.2-t2i-flash";
|
||||
const PUZZLE_IMAGE_MODEL_ORIGINAL: &str = "original";
|
||||
const PUZZLE_IMAGE_MODEL_GPT_IMAGE_2: &str = "gpt-image-2";
|
||||
const PUZZLE_IMAGE_MODEL_GEMINI_31_FLASH_PREVIEW: &str = "gemini-3.1-flash-image-preview";
|
||||
const PUZZLE_ENTITY_KIND: &str = "puzzle_work";
|
||||
const PUZZLE_GENERATED_IMAGE_SIZE: &str = "1024*1024";
|
||||
const PUZZLE_APIMART_GENERATED_IMAGE_SIZE: &str = "1:1";
|
||||
const PUZZLE_APIMART_GEMINI_RESOLUTION: &str = "1K";
|
||||
|
||||
pub async fn create_puzzle_agent_session(
|
||||
State(state): State<AppState>,
|
||||
@@ -463,6 +468,7 @@ pub async fn execute_puzzle_agent_action(
|
||||
session_id = %session_id,
|
||||
owner_user_id = %owner_user_id,
|
||||
action = %action,
|
||||
image_model = payload.image_model.as_deref().unwrap_or(PUZZLE_IMAGE_MODEL_ORIGINAL),
|
||||
prompt_chars = payload
|
||||
.prompt_text
|
||||
.as_deref()
|
||||
@@ -508,6 +514,7 @@ pub async fn execute_puzzle_agent_action(
|
||||
owner_user_id.clone(),
|
||||
prompt_text,
|
||||
payload.reference_image_src.as_deref(),
|
||||
payload.image_model.as_deref(),
|
||||
now,
|
||||
)
|
||||
.await
|
||||
@@ -627,6 +634,7 @@ pub async fn execute_puzzle_agent_action(
|
||||
&target_level.level_name,
|
||||
&prompt,
|
||||
payload.reference_image_src.as_deref(),
|
||||
payload.image_model.as_deref(),
|
||||
candidate_count,
|
||||
candidate_start_index,
|
||||
)
|
||||
@@ -2406,6 +2414,7 @@ async fn compile_puzzle_draft_with_initial_cover(
|
||||
owner_user_id: String,
|
||||
prompt_text: Option<&str>,
|
||||
reference_image_src: Option<&str>,
|
||||
image_model: Option<&str>,
|
||||
now: i64,
|
||||
) -> Result<PuzzleAgentSessionRecord, SpacetimeClientError> {
|
||||
let compiled_session = state
|
||||
@@ -2431,6 +2440,7 @@ async fn compile_puzzle_draft_with_initial_cover(
|
||||
&target_level.level_name,
|
||||
&image_prompt,
|
||||
reference_image_src,
|
||||
image_model,
|
||||
1,
|
||||
target_level.candidates.len(),
|
||||
)
|
||||
@@ -2544,7 +2554,12 @@ fn is_missing_puzzle_form_draft_procedure_error(error: &SpacetimeClientError) ->
|
||||
|
||||
fn map_puzzle_compile_error(error: SpacetimeClientError) -> AppError {
|
||||
let message = error.to_string();
|
||||
let provider = if message.contains("DashScope") || message.contains("dashscope") {
|
||||
let provider = if message.contains("APIMart")
|
||||
|| message.contains("apimart")
|
||||
|| message.contains("APIMART")
|
||||
{
|
||||
"apimart"
|
||||
} else if message.contains("DashScope") || message.contains("dashscope") {
|
||||
"dashscope"
|
||||
} else if message.contains("OSS") || message.contains("oss") || message.contains("参考图") {
|
||||
"puzzle-assets"
|
||||
@@ -2556,6 +2571,9 @@ fn map_puzzle_compile_error(error: SpacetimeClientError) -> AppError {
|
||||
|| message.contains("上游")
|
||||
|| message.contains("DashScope")
|
||||
|| message.contains("dashscope")
|
||||
|| message.contains("APIMart")
|
||||
|| message.contains("apimart")
|
||||
|| message.contains("APIMART")
|
||||
|| message.contains("参考图")
|
||||
|| message.contains("图片")
|
||||
|| message.contains("OSS")
|
||||
@@ -2648,17 +2666,18 @@ async fn generate_puzzle_image_candidates(
|
||||
level_name: &str,
|
||||
prompt: &str,
|
||||
reference_image_src: Option<&str>,
|
||||
image_model: Option<&str>,
|
||||
candidate_count: u32,
|
||||
candidate_start_index: usize,
|
||||
) -> Result<Vec<PuzzleGeneratedImageCandidateRecord>, String> {
|
||||
let count = candidate_count.clamp(1, 1);
|
||||
let settings =
|
||||
require_puzzle_dashscope_settings(state).map_err(map_puzzle_generation_app_error)?;
|
||||
let http_client =
|
||||
build_puzzle_dashscope_http_client(&settings).map_err(map_puzzle_generation_app_error)?;
|
||||
let resolved_model = resolve_puzzle_image_model(image_model);
|
||||
let actual_prompt = build_puzzle_image_prompt(level_name, prompt);
|
||||
let http_client = build_puzzle_image_http_client(state, resolved_model)
|
||||
.map_err(map_puzzle_generation_app_error)?;
|
||||
tracing::info!(
|
||||
provider = "dashscope",
|
||||
provider = resolved_model.provider_name(),
|
||||
image_model = resolved_model.request_model_name(),
|
||||
session_id,
|
||||
level_name,
|
||||
prompt_chars = prompt.chars().count(),
|
||||
@@ -2680,29 +2699,50 @@ async fn generate_puzzle_image_candidates(
|
||||
),
|
||||
None => None,
|
||||
};
|
||||
// 中文注释:SpacetimeDB reducer 不能做外部 I/O,参考图读取与 DashScope 图生图都必须停留在 api-server。
|
||||
// 中文注释:SpacetimeDB reducer 不能做外部 I/O,参考图读取与外部生图都必须停留在 api-server。
|
||||
// 中文注释:拼图作品资产统一按 1:1 正方形生成,前端运行时也按正方形棋盘切块承载。
|
||||
let generated = match reference_image.as_deref() {
|
||||
Some(reference_image) => {
|
||||
create_puzzle_image_to_image_generation(
|
||||
&http_client,
|
||||
&settings,
|
||||
actual_prompt.as_str(),
|
||||
PUZZLE_DEFAULT_NEGATIVE_PROMPT,
|
||||
PUZZLE_GENERATED_IMAGE_SIZE,
|
||||
count,
|
||||
reference_image,
|
||||
)
|
||||
.await
|
||||
let generated = match resolved_model {
|
||||
PuzzleImageModel::Original => {
|
||||
let settings = require_puzzle_dashscope_settings(state)
|
||||
.map_err(map_puzzle_generation_app_error)?;
|
||||
match reference_image.as_deref() {
|
||||
Some(reference_image) => {
|
||||
create_puzzle_image_to_image_generation(
|
||||
&http_client,
|
||||
&settings,
|
||||
actual_prompt.as_str(),
|
||||
PUZZLE_DEFAULT_NEGATIVE_PROMPT,
|
||||
PUZZLE_GENERATED_IMAGE_SIZE,
|
||||
count,
|
||||
reference_image,
|
||||
)
|
||||
.await
|
||||
}
|
||||
None => {
|
||||
create_puzzle_text_to_image_generation(
|
||||
&http_client,
|
||||
&settings,
|
||||
actual_prompt.as_str(),
|
||||
PUZZLE_DEFAULT_NEGATIVE_PROMPT,
|
||||
PUZZLE_GENERATED_IMAGE_SIZE,
|
||||
count,
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
}
|
||||
None => {
|
||||
create_puzzle_text_to_image_generation(
|
||||
PuzzleImageModel::GptImage2 | PuzzleImageModel::Gemini31FlashPreview => {
|
||||
let settings =
|
||||
require_puzzle_apimart_settings(state).map_err(map_puzzle_generation_app_error)?;
|
||||
create_puzzle_apimart_image_generation(
|
||||
&http_client,
|
||||
&settings,
|
||||
resolved_model,
|
||||
actual_prompt.as_str(),
|
||||
PUZZLE_DEFAULT_NEGATIVE_PROMPT,
|
||||
PUZZLE_GENERATED_IMAGE_SIZE,
|
||||
PUZZLE_APIMART_GENERATED_IMAGE_SIZE,
|
||||
count,
|
||||
reference_image.as_deref(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
@@ -2733,7 +2773,7 @@ async fn generate_puzzle_image_candidates(
|
||||
asset_id: asset.asset_id,
|
||||
prompt: prompt.to_string(),
|
||||
actual_prompt: Some(actual_prompt.clone()),
|
||||
source_type: "generated".to_string(),
|
||||
source_type: resolved_model.candidate_source_type().to_string(),
|
||||
// 单图生成结果总是直接成为当前正式图。
|
||||
selected: index == 0,
|
||||
});
|
||||
@@ -2823,6 +2863,7 @@ async fn build_local_next_puzzle_run(
|
||||
&draft.level_name,
|
||||
&draft.summary,
|
||||
None,
|
||||
None,
|
||||
1,
|
||||
draft.candidates.len(),
|
||||
)
|
||||
@@ -3577,6 +3618,7 @@ mod tests {
|
||||
#[test]
|
||||
fn puzzle_generated_image_size_is_square_1_1() {
|
||||
assert_eq!(PUZZLE_GENERATED_IMAGE_SIZE, "1024*1024");
|
||||
assert_eq!(PUZZLE_APIMART_GENERATED_IMAGE_SIZE, "1:1");
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -3598,6 +3640,30 @@ mod tests {
|
||||
assert_eq!(body["parameters"]["n"], 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn puzzle_apimart_request_uses_selected_model_and_reference_images() {
|
||||
let body = build_puzzle_apimart_image_request_body(
|
||||
PuzzleImageModel::Gemini31FlashPreview,
|
||||
"一只猫在雨夜灯牌下回头。",
|
||||
PUZZLE_DEFAULT_NEGATIVE_PROMPT,
|
||||
PUZZLE_APIMART_GENERATED_IMAGE_SIZE,
|
||||
4,
|
||||
Some("data:image/png;base64,abcd"),
|
||||
);
|
||||
|
||||
assert_eq!(body["model"], PUZZLE_IMAGE_MODEL_GEMINI_31_FLASH_PREVIEW);
|
||||
assert_eq!(body["size"], PUZZLE_APIMART_GENERATED_IMAGE_SIZE);
|
||||
assert_eq!(body["resolution"], PUZZLE_APIMART_GEMINI_RESOLUTION);
|
||||
assert_eq!(body["n"], 1);
|
||||
assert_eq!(body["image_urls"][0], "data:image/png;base64,abcd");
|
||||
assert!(
|
||||
body["prompt"]
|
||||
.as_str()
|
||||
.unwrap_or_default()
|
||||
.contains("文字水印")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn puzzle_dashscope_upstream_error_keeps_status_and_raw_excerpt() {
|
||||
let error = map_puzzle_dashscope_upstream_error(
|
||||
@@ -3639,6 +3705,44 @@ struct PuzzleDashScopeSettings {
|
||||
request_timeout_ms: u64,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
|
||||
enum PuzzleImageModel {
|
||||
Original,
|
||||
GptImage2,
|
||||
Gemini31FlashPreview,
|
||||
}
|
||||
|
||||
impl PuzzleImageModel {
|
||||
fn provider_name(self) -> &'static str {
|
||||
match self {
|
||||
Self::Original => "dashscope",
|
||||
Self::GptImage2 | Self::Gemini31FlashPreview => "apimart",
|
||||
}
|
||||
}
|
||||
|
||||
fn request_model_name(self) -> &'static str {
|
||||
match self {
|
||||
Self::Original => PUZZLE_TEXT_TO_IMAGE_MODEL,
|
||||
Self::GptImage2 => PUZZLE_IMAGE_MODEL_GPT_IMAGE_2,
|
||||
Self::Gemini31FlashPreview => PUZZLE_IMAGE_MODEL_GEMINI_31_FLASH_PREVIEW,
|
||||
}
|
||||
}
|
||||
|
||||
fn candidate_source_type(self) -> &'static str {
|
||||
match self {
|
||||
Self::Original => "generated",
|
||||
Self::GptImage2 => "generated:gpt-image-2",
|
||||
Self::Gemini31FlashPreview => "generated:nanobanana2",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct PuzzleApimartSettings {
|
||||
base_url: String,
|
||||
api_key: String,
|
||||
request_timeout_ms: u64,
|
||||
}
|
||||
|
||||
struct PuzzleGeneratedImages {
|
||||
task_id: String,
|
||||
images: Vec<PuzzleDownloadedImage>,
|
||||
@@ -3694,16 +3798,69 @@ fn require_puzzle_dashscope_settings(
|
||||
})
|
||||
}
|
||||
|
||||
fn build_puzzle_dashscope_http_client(
|
||||
settings: &PuzzleDashScopeSettings,
|
||||
fn resolve_puzzle_image_model(value: Option<&str>) -> PuzzleImageModel {
|
||||
match value
|
||||
.map(str::trim)
|
||||
.filter(|value| !value.is_empty())
|
||||
.unwrap_or(PUZZLE_IMAGE_MODEL_ORIGINAL)
|
||||
{
|
||||
PUZZLE_IMAGE_MODEL_GPT_IMAGE_2 => PuzzleImageModel::GptImage2,
|
||||
PUZZLE_IMAGE_MODEL_GEMINI_31_FLASH_PREVIEW => PuzzleImageModel::Gemini31FlashPreview,
|
||||
_ => PuzzleImageModel::Original,
|
||||
}
|
||||
}
|
||||
|
||||
fn require_puzzle_apimart_settings(state: &AppState) -> Result<PuzzleApimartSettings, AppError> {
|
||||
let base_url = state.config.apimart_base_url.trim().trim_end_matches('/');
|
||||
if base_url.is_empty() {
|
||||
return Err(
|
||||
AppError::from_status(StatusCode::SERVICE_UNAVAILABLE).with_details(json!({
|
||||
"provider": "apimart",
|
||||
"reason": "APIMART_BASE_URL 未配置",
|
||||
})),
|
||||
);
|
||||
}
|
||||
|
||||
let api_key = state
|
||||
.config
|
||||
.apimart_api_key
|
||||
.as_deref()
|
||||
.map(str::trim)
|
||||
.filter(|value| !value.is_empty())
|
||||
.ok_or_else(|| {
|
||||
AppError::from_status(StatusCode::SERVICE_UNAVAILABLE).with_details(json!({
|
||||
"provider": "apimart",
|
||||
"reason": "APIMART_API_KEY 未配置",
|
||||
}))
|
||||
})?;
|
||||
|
||||
Ok(PuzzleApimartSettings {
|
||||
base_url: base_url.to_string(),
|
||||
api_key: api_key.to_string(),
|
||||
request_timeout_ms: state.config.apimart_image_request_timeout_ms.max(1),
|
||||
})
|
||||
}
|
||||
|
||||
fn build_puzzle_image_http_client(
|
||||
state: &AppState,
|
||||
image_model: PuzzleImageModel,
|
||||
) -> Result<reqwest::Client, AppError> {
|
||||
let (provider, request_timeout_ms) = match image_model {
|
||||
PuzzleImageModel::Original => {
|
||||
("dashscope", state.config.dashscope_image_request_timeout_ms)
|
||||
}
|
||||
PuzzleImageModel::GptImage2 | PuzzleImageModel::Gemini31FlashPreview => {
|
||||
("apimart", state.config.apimart_image_request_timeout_ms)
|
||||
}
|
||||
};
|
||||
|
||||
reqwest::Client::builder()
|
||||
.timeout(Duration::from_millis(settings.request_timeout_ms))
|
||||
.timeout(Duration::from_millis(request_timeout_ms.max(1)))
|
||||
.build()
|
||||
.map_err(|error| {
|
||||
AppError::from_status(StatusCode::INTERNAL_SERVER_ERROR).with_details(json!({
|
||||
"provider": "dashscope",
|
||||
"message": format!("构造拼图 DashScope HTTP 客户端失败:{error}"),
|
||||
"provider": provider,
|
||||
"message": format!("构造拼图图片生成 HTTP 客户端失败:{error}"),
|
||||
}))
|
||||
})
|
||||
}
|
||||
@@ -3866,6 +4023,229 @@ fn build_puzzle_text_to_image_request_body(
|
||||
})
|
||||
}
|
||||
|
||||
async fn create_puzzle_apimart_image_generation(
|
||||
http_client: &reqwest::Client,
|
||||
settings: &PuzzleApimartSettings,
|
||||
image_model: PuzzleImageModel,
|
||||
prompt: &str,
|
||||
negative_prompt: &str,
|
||||
size: &str,
|
||||
candidate_count: u32,
|
||||
reference_image: Option<&str>,
|
||||
) -> Result<PuzzleGeneratedImages, AppError> {
|
||||
let request_body = build_puzzle_apimart_image_request_body(
|
||||
image_model,
|
||||
prompt,
|
||||
negative_prompt,
|
||||
size,
|
||||
candidate_count,
|
||||
reference_image,
|
||||
);
|
||||
let response = http_client
|
||||
.post(format!("{}/images/generations", settings.base_url))
|
||||
.header(
|
||||
reqwest::header::AUTHORIZATION,
|
||||
format!("Bearer {}", settings.api_key),
|
||||
)
|
||||
.header(reqwest::header::CONTENT_TYPE, "application/json")
|
||||
.json(&request_body)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|error| {
|
||||
map_puzzle_apimart_request_error(format!("创建拼图 APIMart 图片生成任务失败:{error}"))
|
||||
})?;
|
||||
let status = response.status();
|
||||
let response_text = response.text().await.map_err(|error| {
|
||||
map_puzzle_apimart_request_error(format!("读取拼图 APIMart 图片生成响应失败:{error}"))
|
||||
})?;
|
||||
if !status.is_success() {
|
||||
return Err(map_puzzle_apimart_upstream_error(
|
||||
status,
|
||||
response_text.as_str(),
|
||||
"创建拼图 APIMart 图片生成任务失败",
|
||||
));
|
||||
}
|
||||
|
||||
let payload =
|
||||
parse_puzzle_json_payload(response_text.as_str(), "解析拼图 APIMart 图片生成响应失败")?;
|
||||
let image_urls = extract_puzzle_image_urls(&payload);
|
||||
if !image_urls.is_empty() {
|
||||
return download_puzzle_images_from_urls(
|
||||
http_client,
|
||||
format!("apimart-{}", current_utc_micros()),
|
||||
image_urls,
|
||||
candidate_count,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
let task_id = extract_puzzle_task_id(&payload).ok_or_else(|| {
|
||||
AppError::from_status(StatusCode::BAD_GATEWAY).with_details(json!({
|
||||
"provider": "apimart",
|
||||
"message": "拼图 APIMart 图片生成未返回 task_id 或图片地址",
|
||||
}))
|
||||
})?;
|
||||
|
||||
wait_puzzle_apimart_generated_images(
|
||||
http_client,
|
||||
settings,
|
||||
task_id.as_str(),
|
||||
candidate_count,
|
||||
"拼图 APIMart 图片生成任务失败",
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
fn build_puzzle_apimart_image_request_body(
|
||||
image_model: PuzzleImageModel,
|
||||
prompt: &str,
|
||||
negative_prompt: &str,
|
||||
size: &str,
|
||||
candidate_count: u32,
|
||||
reference_image: Option<&str>,
|
||||
) -> Value {
|
||||
let mut body = Map::from_iter([
|
||||
(
|
||||
"model".to_string(),
|
||||
Value::String(image_model.request_model_name().to_string()),
|
||||
),
|
||||
(
|
||||
"prompt".to_string(),
|
||||
Value::String(build_puzzle_apimart_prompt(prompt, negative_prompt)),
|
||||
),
|
||||
("n".to_string(), json!(candidate_count.clamp(1, 1))),
|
||||
("size".to_string(), Value::String(size.to_string())),
|
||||
]);
|
||||
body.insert(
|
||||
"resolution".to_string(),
|
||||
Value::String(
|
||||
match image_model {
|
||||
PuzzleImageModel::Gemini31FlashPreview => PUZZLE_APIMART_GEMINI_RESOLUTION,
|
||||
_ => "1k",
|
||||
}
|
||||
.to_string(),
|
||||
),
|
||||
);
|
||||
|
||||
if let Some(reference_image) = reference_image
|
||||
.map(str::trim)
|
||||
.filter(|value| !value.is_empty())
|
||||
{
|
||||
body.insert("image_urls".to_string(), json!([reference_image]));
|
||||
}
|
||||
|
||||
Value::Object(body)
|
||||
}
|
||||
|
||||
fn build_puzzle_apimart_prompt(prompt: &str, negative_prompt: &str) -> String {
|
||||
let prompt = prompt.trim();
|
||||
let negative_prompt = negative_prompt.trim();
|
||||
if negative_prompt.is_empty() {
|
||||
return prompt.to_string();
|
||||
}
|
||||
|
||||
format!("{prompt}\n避免:{negative_prompt}")
|
||||
}
|
||||
|
||||
async fn wait_puzzle_apimart_generated_images(
|
||||
http_client: &reqwest::Client,
|
||||
settings: &PuzzleApimartSettings,
|
||||
task_id: &str,
|
||||
candidate_count: u32,
|
||||
failure_message: &str,
|
||||
) -> Result<PuzzleGeneratedImages, AppError> {
|
||||
let deadline = Instant::now() + Duration::from_millis(settings.request_timeout_ms);
|
||||
|
||||
while Instant::now() < deadline {
|
||||
let poll_response = http_client
|
||||
.get(format!("{}/tasks/{}", settings.base_url, task_id))
|
||||
.header(
|
||||
reqwest::header::AUTHORIZATION,
|
||||
format!("Bearer {}", settings.api_key),
|
||||
)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|error| {
|
||||
map_puzzle_apimart_request_error(format!(
|
||||
"查询拼图 APIMart 图片生成任务失败:{error}"
|
||||
))
|
||||
})?;
|
||||
let poll_status = poll_response.status();
|
||||
let poll_text = poll_response.text().await.map_err(|error| {
|
||||
map_puzzle_apimart_request_error(format!(
|
||||
"读取拼图 APIMart 图片生成任务响应失败:{error}"
|
||||
))
|
||||
})?;
|
||||
if !poll_status.is_success() {
|
||||
return Err(map_puzzle_apimart_upstream_error(
|
||||
poll_status,
|
||||
poll_text.as_str(),
|
||||
"查询拼图 APIMart 图片生成任务失败",
|
||||
));
|
||||
}
|
||||
|
||||
let poll_payload =
|
||||
parse_puzzle_json_payload(poll_text.as_str(), "解析拼图 APIMart 图片生成任务响应失败")?;
|
||||
let task_status = find_first_puzzle_string_by_key(&poll_payload, "status")
|
||||
.unwrap_or_default()
|
||||
.trim()
|
||||
.to_ascii_lowercase();
|
||||
if matches!(task_status.as_str(), "completed" | "succeeded" | "success") {
|
||||
let image_urls = extract_puzzle_image_urls(&poll_payload);
|
||||
if image_urls.is_empty() {
|
||||
return Err(
|
||||
AppError::from_status(StatusCode::BAD_GATEWAY).with_details(json!({
|
||||
"provider": "apimart",
|
||||
"message": "拼图 APIMart 图片生成成功但未返回图片地址",
|
||||
})),
|
||||
);
|
||||
}
|
||||
|
||||
return download_puzzle_images_from_urls(
|
||||
http_client,
|
||||
task_id.to_string(),
|
||||
image_urls,
|
||||
candidate_count,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
if matches!(
|
||||
task_status.as_str(),
|
||||
"failed" | "error" | "canceled" | "cancelled"
|
||||
) {
|
||||
return Err(map_puzzle_apimart_upstream_error(
|
||||
poll_status,
|
||||
poll_text.as_str(),
|
||||
failure_message,
|
||||
));
|
||||
}
|
||||
sleep(Duration::from_secs(3)).await;
|
||||
}
|
||||
|
||||
Err(
|
||||
AppError::from_status(StatusCode::BAD_GATEWAY).with_details(json!({
|
||||
"provider": "apimart",
|
||||
"message": "拼图 APIMart 图片生成超时或未返回图片地址",
|
||||
})),
|
||||
)
|
||||
}
|
||||
|
||||
async fn download_puzzle_images_from_urls(
|
||||
http_client: &reqwest::Client,
|
||||
task_id: String,
|
||||
image_urls: Vec<String>,
|
||||
candidate_count: u32,
|
||||
) -> Result<PuzzleGeneratedImages, AppError> {
|
||||
let mut images = Vec::with_capacity(candidate_count.clamp(1, 1) as usize);
|
||||
for image_url in image_urls
|
||||
.into_iter()
|
||||
.take(candidate_count.clamp(1, 1) as usize)
|
||||
{
|
||||
images.push(download_puzzle_remote_image(http_client, image_url.as_str()).await?);
|
||||
}
|
||||
Ok(PuzzleGeneratedImages { task_id, images })
|
||||
}
|
||||
|
||||
async fn resolve_puzzle_reference_image_as_data_url(
|
||||
state: &AppState,
|
||||
http_client: &reqwest::Client,
|
||||
@@ -4427,6 +4807,36 @@ fn map_puzzle_dashscope_upstream_error(
|
||||
}))
|
||||
}
|
||||
|
||||
fn map_puzzle_apimart_request_error(message: String) -> AppError {
|
||||
AppError::from_status(StatusCode::BAD_GATEWAY).with_details(json!({
|
||||
"provider": "apimart",
|
||||
"message": message,
|
||||
}))
|
||||
}
|
||||
|
||||
fn map_puzzle_apimart_upstream_error(
|
||||
upstream_status: reqwest::StatusCode,
|
||||
raw_text: &str,
|
||||
fallback_message: &str,
|
||||
) -> AppError {
|
||||
let message = parse_puzzle_api_error_message(raw_text, fallback_message);
|
||||
let raw_excerpt = trim_puzzle_upstream_excerpt(raw_text, 800);
|
||||
tracing::warn!(
|
||||
provider = "apimart",
|
||||
upstream_status = upstream_status.as_u16(),
|
||||
message = %message,
|
||||
raw_excerpt = %raw_excerpt,
|
||||
"拼图 APIMart 上游请求失败"
|
||||
);
|
||||
|
||||
AppError::from_status(StatusCode::BAD_GATEWAY).with_details(json!({
|
||||
"provider": "apimart",
|
||||
"upstreamStatus": upstream_status.as_u16(),
|
||||
"message": message,
|
||||
"rawExcerpt": raw_excerpt,
|
||||
}))
|
||||
}
|
||||
|
||||
fn parse_puzzle_api_error_message(raw_text: &str, fallback_message: &str) -> String {
|
||||
let trimmed = raw_text.trim();
|
||||
if trimmed.is_empty() {
|
||||
|
||||
Reference in New Issue
Block a user