338 lines
12 KiB
Rust
338 lines
12 KiB
Rust
use axum::{
|
||
Json,
|
||
extract::{Extension, State, rejection::JsonRejection},
|
||
http::StatusCode,
|
||
response::Response,
|
||
};
|
||
use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64_STANDARD};
|
||
use image::{ColorType, ImageEncoder, codecs::png::PngEncoder};
|
||
use serde::{Deserialize, Serialize};
|
||
use serde_json::{Value, json};
|
||
|
||
use crate::{
|
||
api_response::json_success_body,
|
||
http_error::AppError,
|
||
openai_image_generation::{
|
||
DownloadedOpenAiImage, build_openai_image_http_client, create_openai_image_generation,
|
||
require_openai_image_settings,
|
||
},
|
||
request_context::RequestContext,
|
||
state::AppState,
|
||
};
|
||
|
||
/// Maximum number of strokes accepted in `strokeTrace`; longer traces are rejected with 400
/// by `validate_magic_request`.
const BABY_LOVE_DRAWING_MAX_STROKES: usize = 600;

/// Output size requested from the image generator for every magic picture.
const BABY_LOVE_DRAWING_IMAGE_SIZE: &str = "1024x1024";

/// Provider label echoed back to clients in the response's `generationProvider` field.
const BABY_LOVE_DRAWING_PROVIDER: &str = "vector-engine-gpt-image-2";
|
||
|
||
#[derive(Debug, Deserialize)]
|
||
#[serde(rename_all = "camelCase")]
|
||
pub(crate) struct CreateBabyLoveDrawingMagicRequest {
|
||
original_image_src: String,
|
||
#[serde(default)]
|
||
stroke_trace: Vec<BabyLoveDrawingStrokePayload>,
|
||
}
|
||
|
||
#[derive(Debug, Deserialize)]
|
||
#[serde(rename_all = "camelCase")]
|
||
struct BabyLoveDrawingStrokePayload {
|
||
stroke_id: String,
|
||
tool: String,
|
||
color: String,
|
||
#[serde(default)]
|
||
points: Vec<BabyLoveDrawingPointPayload>,
|
||
}
|
||
|
||
#[derive(Debug, Deserialize)]
|
||
struct BabyLoveDrawingPointPayload {
|
||
x: f64,
|
||
y: f64,
|
||
t: f64,
|
||
}
|
||
|
||
#[derive(Debug, Serialize)]
|
||
#[serde(rename_all = "camelCase")]
|
||
struct CreateBabyLoveDrawingMagicResponse {
|
||
magic_image_src: String,
|
||
generation_provider: String,
|
||
prompt: String,
|
||
}
|
||
|
||
pub async fn create_baby_love_drawing_magic(
|
||
State(state): State<AppState>,
|
||
Extension(request_context): Extension<RequestContext>,
|
||
payload: Result<Json<CreateBabyLoveDrawingMagicRequest>, JsonRejection>,
|
||
) -> Result<Json<Value>, Response> {
|
||
let Json(payload) = payload.map_err(|error| {
|
||
baby_love_drawing_error_response(
|
||
&request_context,
|
||
AppError::from_status(StatusCode::BAD_REQUEST).with_details(json!({
|
||
"provider": "edutainment-baby-drawing",
|
||
"message": error.body_text(),
|
||
})),
|
||
)
|
||
})?;
|
||
validate_magic_request(&payload)
|
||
.map_err(|error| baby_love_drawing_error_response(&request_context, error))?;
|
||
|
||
let settings = require_openai_image_settings(&state)
|
||
.map_err(|error| baby_love_drawing_error_response(&request_context, error))?;
|
||
let http_client = build_openai_image_http_client(&settings)
|
||
.map_err(|error| baby_love_drawing_error_response(&request_context, error))?;
|
||
let prompt = build_baby_love_drawing_magic_prompt(payload.stroke_trace.as_slice());
|
||
let reference_images = vec![payload.original_image_src.trim().to_string()];
|
||
let generated = create_openai_image_generation(
|
||
&http_client,
|
||
&settings,
|
||
prompt.as_str(),
|
||
Some(build_baby_love_drawing_negative_prompt()),
|
||
BABY_LOVE_DRAWING_IMAGE_SIZE,
|
||
1,
|
||
reference_images.as_slice(),
|
||
"宝贝爱画绘画魔法图片生成失败",
|
||
)
|
||
.await
|
||
.map_err(|error| baby_love_drawing_error_response(&request_context, error))?;
|
||
let generated_image = generated.images.into_iter().next().ok_or_else(|| {
|
||
baby_love_drawing_error_response(
|
||
&request_context,
|
||
AppError::from_status(StatusCode::BAD_GATEWAY).with_details(json!({
|
||
"provider": "vector-engine",
|
||
"message": "宝贝爱画绘画魔法没有返回图片。",
|
||
})),
|
||
)
|
||
})?;
|
||
let magic_image_src = build_png_data_url(generated_image)
|
||
.map_err(|error| baby_love_drawing_error_response(&request_context, error))?;
|
||
|
||
Ok(json_success_body(
|
||
Some(&request_context),
|
||
CreateBabyLoveDrawingMagicResponse {
|
||
magic_image_src,
|
||
generation_provider: BABY_LOVE_DRAWING_PROVIDER.to_string(),
|
||
prompt,
|
||
},
|
||
))
|
||
}
|
||
|
||
fn validate_magic_request(payload: &CreateBabyLoveDrawingMagicRequest) -> Result<(), AppError> {
|
||
let original_image_src = payload.original_image_src.trim();
|
||
if !original_image_src.starts_with("data:image/") || !original_image_src.contains(";base64,") {
|
||
return Err(
|
||
AppError::from_status(StatusCode::BAD_REQUEST).with_details(json!({
|
||
"provider": "edutainment-baby-drawing",
|
||
"message": "绘画原图必须是图片 Data URL。",
|
||
})),
|
||
);
|
||
}
|
||
|
||
if payload.stroke_trace.len() > BABY_LOVE_DRAWING_MAX_STROKES {
|
||
return Err(
|
||
AppError::from_status(StatusCode::BAD_REQUEST).with_details(json!({
|
||
"provider": "edutainment-baby-drawing",
|
||
"message": "绘画笔触数量过多,请重新完成绘画后再使用魔法。",
|
||
})),
|
||
);
|
||
}
|
||
|
||
Ok(())
|
||
}
|
||
|
||
fn build_baby_love_drawing_magic_prompt(stroke_trace: &[BabyLoveDrawingStrokePayload]) -> String {
|
||
let stroke_count = stroke_trace.len();
|
||
let brush_count = stroke_trace
|
||
.iter()
|
||
.filter(|stroke| stroke.tool.trim() == "brush")
|
||
.count();
|
||
let eraser_count = stroke_trace
|
||
.iter()
|
||
.filter(|stroke| stroke.tool.trim() == "eraser")
|
||
.count();
|
||
let color_summary = summarize_stroke_colors(stroke_trace);
|
||
let trace_bounds = summarize_trace_bounds(stroke_trace);
|
||
|
||
format!(
|
||
"根据参考图中的儿童绘画内容,为寓教于乐独立关卡“宝贝爱画”生成一张绘本风格图片。\n\
|
||
必须保留小朋友原始画面的主体构图、线条方向、颜色关系和童趣笔触,不要改成与原图无关的新内容。\n\
|
||
输出风格:明亮、温暖、柔和、卡通绘本风格,适合 4-8 岁儿童,画面干净,边缘柔和,有轻微纸面质感。\n\
|
||
笔触信息:总笔触 {stroke_count} 条,画笔 {brush_count} 条,橡皮 {eraser_count} 条,主要颜色 {color_summary},绘制范围 {trace_bounds}。\n\
|
||
不要生成文字、水印、Logo、按钮、UI 面板、真实照片风、恐怖或成人化内容。"
|
||
)
|
||
}
|
||
|
||
fn summarize_stroke_colors(stroke_trace: &[BabyLoveDrawingStrokePayload]) -> String {
|
||
let mut colors = Vec::new();
|
||
for stroke in stroke_trace {
|
||
if stroke.stroke_id.trim().is_empty() {
|
||
continue;
|
||
}
|
||
let color = stroke.color.trim();
|
||
if color.is_empty() || colors.iter().any(|value| value == color) {
|
||
continue;
|
||
}
|
||
colors.push(color.to_string());
|
||
if colors.len() >= 5 {
|
||
break;
|
||
}
|
||
}
|
||
|
||
if colors.is_empty() {
|
||
"无明显颜色记录".to_string()
|
||
} else {
|
||
colors.join("、")
|
||
}
|
||
}
|
||
|
||
fn summarize_trace_bounds(stroke_trace: &[BabyLoveDrawingStrokePayload]) -> String {
|
||
let mut min_x = 1.0_f64;
|
||
let mut min_y = 1.0_f64;
|
||
let mut max_x = 0.0_f64;
|
||
let mut max_y = 0.0_f64;
|
||
let mut has_point = false;
|
||
|
||
for point in stroke_trace.iter().flat_map(|stroke| stroke.points.iter()) {
|
||
if !(point.x.is_finite() && point.y.is_finite() && point.t.is_finite()) {
|
||
continue;
|
||
}
|
||
has_point = true;
|
||
min_x = min_x.min(point.x.clamp(0.0, 1.0));
|
||
min_y = min_y.min(point.y.clamp(0.0, 1.0));
|
||
max_x = max_x.max(point.x.clamp(0.0, 1.0));
|
||
max_y = max_y.max(point.y.clamp(0.0, 1.0));
|
||
}
|
||
|
||
if !has_point {
|
||
return "无可用坐标记录".to_string();
|
||
}
|
||
|
||
format!("x {:.2}-{:.2}, y {:.2}-{:.2}", min_x, max_x, min_y, max_y)
|
||
}
|
||
|
||
/// Negative prompt listing content the generated picture must not contain.
fn build_baby_love_drawing_negative_prompt() -> &'static str {
    const NEGATIVE_PROMPT: &str = "文字,水印,Logo,按钮,UI,面板,复杂背景,真实照片风,恐怖元素,成人化内容,攻击性内容,替换原图主体,完全无关的新画面";
    NEGATIVE_PROMPT
}
|
||
|
||
fn build_png_data_url(image: DownloadedOpenAiImage) -> Result<String, AppError> {
|
||
let png_bytes = normalize_generated_image_to_png(image.bytes.as_slice())?;
|
||
Ok(format!(
|
||
"data:image/png;base64,{}",
|
||
BASE64_STANDARD.encode(png_bytes)
|
||
))
|
||
}
|
||
|
||
fn normalize_generated_image_to_png(source: &[u8]) -> Result<Vec<u8>, AppError> {
|
||
let rgba_image = image::load_from_memory(source)
|
||
.map_err(|error| {
|
||
AppError::from_status(StatusCode::BAD_GATEWAY).with_details(json!({
|
||
"provider": "vector-engine",
|
||
"message": format!("解析宝贝爱画魔法图片失败:{error}"),
|
||
}))
|
||
})?
|
||
.to_rgba8();
|
||
let (width, height) = rgba_image.dimensions();
|
||
let mut encoded = Vec::new();
|
||
let encoder = PngEncoder::new(&mut encoded);
|
||
encoder
|
||
.write_image(rgba_image.as_raw(), width, height, ColorType::Rgba8.into())
|
||
.map_err(|error| {
|
||
AppError::from_status(StatusCode::BAD_GATEWAY).with_details(json!({
|
||
"provider": "vector-engine",
|
||
"message": format!("转换宝贝爱画魔法图片为 PNG 失败:{error}"),
|
||
}))
|
||
})?;
|
||
|
||
Ok(encoded)
|
||
}
|
||
|
||
/// Renders an [`AppError`] as an HTTP response that carries this request's context.
///
/// Every failure path of the handler goes through here so error responses stay uniform.
fn baby_love_drawing_error_response(request_context: &RequestContext, error: AppError) -> Response {
    let context = Some(request_context);
    error.into_response_with_context(context)
}
|
||
|
||
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a trace point.
    fn point(x: f64, y: f64, t: f64) -> BabyLoveDrawingPointPayload {
        BabyLoveDrawingPointPayload { x, y, t }
    }

    /// A valid request: one red brush stroke spanning x 0.2–0.7 and y 0.3–0.8.
    fn sample_request() -> CreateBabyLoveDrawingMagicRequest {
        let red_brush_stroke = BabyLoveDrawingStrokePayload {
            stroke_id: "stroke-1".to_string(),
            tool: "brush".to_string(),
            color: "#ef4444".to_string(),
            points: vec![point(0.2, 0.3, 1.0), point(0.7, 0.8, 2.0)],
        };
        CreateBabyLoveDrawingMagicRequest {
            original_image_src: "data:image/png;base64,abcd".to_string(),
            stroke_trace: vec![red_brush_stroke],
        }
    }

    #[test]
    fn magic_prompt_keeps_child_drawing_and_picture_book_style() {
        let prompt = build_baby_love_drawing_magic_prompt(&sample_request().stroke_trace);

        for expected in ["宝贝爱画", "绘本风格", "保留小朋友原始画面", "#ef4444", "x 0.20-0.70"] {
            assert!(prompt.contains(expected), "prompt should mention {expected:?}: {prompt}");
        }
    }

    #[test]
    fn magic_request_requires_image_data_url() {
        assert!(validate_magic_request(&sample_request()).is_ok());

        let remote_url_request = CreateBabyLoveDrawingMagicRequest {
            original_image_src: "https://example.test/image.png".to_string(),
            ..sample_request()
        };
        assert!(validate_magic_request(&remote_url_request).is_err());
    }

    #[test]
    fn normalizes_png_to_png_data_url() {
        let white_pixels = [255u8; 4 * 2 * 2];
        let mut png_source = Vec::new();
        PngEncoder::new(&mut png_source)
            .write_image(&white_pixels, 2, 2, ColorType::Rgba8.into())
            .expect("test png should encode");

        let data_url = build_png_data_url(DownloadedOpenAiImage {
            bytes: png_source,
            mime_type: "image/png".to_string(),
            extension: "png".to_string(),
        })
        .expect("test png should normalize");

        assert!(data_url.starts_with("data:image/png;base64,"));
    }

    #[test]
    fn trace_summary_ignores_invalid_points() {
        let mut request = sample_request();
        request.stroke_trace[0].points.push(point(f64::NAN, 0.1, 3.0));

        assert_eq!(
            summarize_trace_bounds(&request.stroke_trace),
            "x 0.20-0.70, y 0.30-0.80",
        );
    }
}
|