// Files
// Genarrative/server-rs/crates/api-server/src/edutainment_baby_drawing.rs
//
// 338 lines
// 12 KiB
// Rust
// Raw Blame History
//
// This file contains ambiguous Unicode characters
// This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
use axum::{
Json,
extract::{Extension, State, rejection::JsonRejection},
http::StatusCode,
response::Response,
};
use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64_STANDARD};
use image::{ColorType, ImageEncoder, codecs::png::PngEncoder};
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
use crate::{
api_response::json_success_body,
http_error::AppError,
openai_image_generation::{
DownloadedOpenAiImage, build_openai_image_http_client, create_openai_image_generation,
require_openai_image_settings,
},
request_context::RequestContext,
state::AppState,
};
/// Provider label echoed back to the client in the response's `generationProvider` field.
const BABY_LOVE_DRAWING_PROVIDER: &str = "vector-engine-gpt-image-2";
/// Size string handed to the image generation call for every request.
const BABY_LOVE_DRAWING_IMAGE_SIZE: &str = "1024x1024";
/// Upper bound on the number of strokes in a request; longer traces are rejected as a bad request.
const BABY_LOVE_DRAWING_MAX_STROKES: usize = 600;
/// JSON request body accepted by `create_baby_love_drawing_magic`.
///
/// Keys are camelCase on the wire (`originalImageSrc`, `strokeTrace`).
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct CreateBabyLoveDrawingMagicRequest {
/// The drawing to transform. `validate_magic_request` requires it to be a base64 image
/// data URL, and the handler forwards it (trimmed) as the single reference image.
original_image_src: String,
/// Recorded strokes used to describe the drawing in the prompt; an absent key
/// deserializes to an empty list.
#[serde(default)]
stroke_trace: Vec<BabyLoveDrawingStrokePayload>,
}
/// One stroke of the recorded drawing trace (camelCase keys on the wire).
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct BabyLoveDrawingStrokePayload {
/// Stroke identifier; strokes whose id is blank are skipped by the colour summary.
stroke_id: String,
/// Tool name; the prompt builder counts the values `"brush"` and `"eraser"`.
tool: String,
/// Colour string, copied verbatim (trimmed) into the prompt's colour summary.
color: String,
/// Sampled points of the stroke; an absent key deserializes to an empty list.
#[serde(default)]
points: Vec<BabyLoveDrawingPointPayload>,
}
/// A single sampled point of a stroke.
#[derive(Debug, Deserialize)]
struct BabyLoveDrawingPointPayload {
/// Horizontal coordinate; clamped to `[0, 1]` when the trace bounds are summarised
/// (presumably normalised to the canvas width — confirm with the client).
x: f64,
/// Vertical coordinate; clamped to `[0, 1]` when the trace bounds are summarised
/// (presumably normalised to the canvas height — confirm with the client).
y: f64,
/// Only checked for finiteness in this file (presumably a timestamp — confirm with the client).
t: f64,
}
/// Success payload returned by `create_baby_love_drawing_magic` (camelCase keys on the wire).
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct CreateBabyLoveDrawingMagicResponse {
/// The generated picture as a `data:image/png;base64,...` URL.
magic_image_src: String,
/// Provider label; the handler fills it from `BABY_LOVE_DRAWING_PROVIDER`.
generation_provider: String,
/// The prompt text that was sent to the image generation call.
prompt: String,
}
pub async fn create_baby_love_drawing_magic(
State(state): State<AppState>,
Extension(request_context): Extension<RequestContext>,
payload: Result<Json<CreateBabyLoveDrawingMagicRequest>, JsonRejection>,
) -> Result<Json<Value>, Response> {
let Json(payload) = payload.map_err(|error| {
baby_love_drawing_error_response(
&request_context,
AppError::from_status(StatusCode::BAD_REQUEST).with_details(json!({
"provider": "edutainment-baby-drawing",
"message": error.body_text(),
})),
)
})?;
validate_magic_request(&payload)
.map_err(|error| baby_love_drawing_error_response(&request_context, error))?;
let settings = require_openai_image_settings(&state)
.map_err(|error| baby_love_drawing_error_response(&request_context, error))?;
let http_client = build_openai_image_http_client(&settings)
.map_err(|error| baby_love_drawing_error_response(&request_context, error))?;
let prompt = build_baby_love_drawing_magic_prompt(payload.stroke_trace.as_slice());
let reference_images = vec![payload.original_image_src.trim().to_string()];
let generated = create_openai_image_generation(
&http_client,
&settings,
prompt.as_str(),
Some(build_baby_love_drawing_negative_prompt()),
BABY_LOVE_DRAWING_IMAGE_SIZE,
1,
reference_images.as_slice(),
"宝贝爱画绘画魔法图片生成失败",
)
.await
.map_err(|error| baby_love_drawing_error_response(&request_context, error))?;
let generated_image = generated.images.into_iter().next().ok_or_else(|| {
baby_love_drawing_error_response(
&request_context,
AppError::from_status(StatusCode::BAD_GATEWAY).with_details(json!({
"provider": "vector-engine",
"message": "宝贝爱画绘画魔法没有返回图片。",
})),
)
})?;
let magic_image_src = build_png_data_url(generated_image)
.map_err(|error| baby_love_drawing_error_response(&request_context, error))?;
Ok(json_success_body(
Some(&request_context),
CreateBabyLoveDrawingMagicResponse {
magic_image_src,
generation_provider: BABY_LOVE_DRAWING_PROVIDER.to_string(),
prompt,
},
))
}
/// Checks that a magic request is acceptable before any generation work starts.
///
/// # Errors
///
/// Returns a `400 Bad Request` `AppError` when the trimmed `original_image_src`
/// is not a `data:image/...;base64,` URL, or when the stroke trace holds more
/// than `BABY_LOVE_DRAWING_MAX_STROKES` strokes.
fn validate_magic_request(payload: &CreateBabyLoveDrawingMagicRequest) -> Result<(), AppError> {
    // Both rejections share the same status and provider tag; only the message differs.
    let reject = |message: &str| {
        AppError::from_status(StatusCode::BAD_REQUEST).with_details(json!({
            "provider": "edutainment-baby-drawing",
            "message": message,
        }))
    };

    let image_src = payload.original_image_src.trim();
    let is_image_data_url =
        image_src.starts_with("data:image/") && image_src.contains(";base64,");
    if !is_image_data_url {
        return Err(reject("绘画原图必须是图片 Data URL。"));
    }

    let stroke_total = payload.stroke_trace.len();
    if stroke_total > BABY_LOVE_DRAWING_MAX_STROKES {
        return Err(reject("绘画笔触数量过多,请重新完成绘画后再使用魔法。"));
    }

    Ok(())
}
/// Builds the five-line generation prompt for a drawing.
///
/// The fixed lines ask for a picture-book rendition that keeps the child's
/// original composition; the fourth line summarises the stroke trace: total
/// strokes, brush and eraser counts, the colour summary and the drawn bounds.
fn build_baby_love_drawing_magic_prompt(stroke_trace: &[BabyLoveDrawingStrokePayload]) -> String {
    // Tally brush and eraser strokes in one pass; any other tool only counts toward the total.
    let (brushes, erasers) = stroke_trace.iter().fold(
        (0_usize, 0_usize),
        |(brushes, erasers), stroke| match stroke.tool.trim() {
            "brush" => (brushes + 1, erasers),
            "eraser" => (brushes, erasers + 1),
            _ => (brushes, erasers),
        },
    );
    let stroke_line = format!(
        "笔触信息:总笔触 {} 条,画笔 {} 条,橡皮 {} 条,主要颜色 {},绘制范围 {}",
        stroke_trace.len(),
        brushes,
        erasers,
        summarize_stroke_colors(stroke_trace),
        summarize_trace_bounds(stroke_trace),
    );
    [
        "根据参考图中的儿童绘画内容,为寓教于乐独立关卡“宝贝爱画”生成一张绘本风格图片。",
        "必须保留小朋友原始画面的主体构图、线条方向、颜色关系和童趣笔触,不要改成与原图无关的新内容。",
        "输出风格:明亮、温暖、柔和、卡通绘本风格,适合 4-8 岁儿童,画面干净,边缘柔和,有轻微纸面质感。",
        stroke_line.as_str(),
        "不要生成文字、水印、Logo、按钮、UI 面板、真实照片风、恐怖或成人化内容。",
    ]
    .join("\n")
}
/// Summarises the distinct colours used in a stroke trace for the prompt.
///
/// Strokes with a blank `stroke_id` or a blank colour are ignored. Colours are
/// trimmed, de-duplicated in first-seen order and capped at five entries, then
/// joined with the enumeration comma `、` so neighbouring colour tokens cannot
/// run into each other in the prompt text.
///
/// Returns `"无明显颜色记录"` when no usable colour is found.
fn summarize_stroke_colors(stroke_trace: &[BabyLoveDrawingStrokePayload]) -> String {
    /// Maximum number of distinct colours listed in the summary.
    const MAX_COLORS: usize = 5;

    // Borrow the trimmed colour slices; nothing needs to be cloned before the final join.
    let mut colors: Vec<&str> = Vec::with_capacity(MAX_COLORS);
    for stroke in stroke_trace {
        if stroke.stroke_id.trim().is_empty() {
            continue;
        }
        let color = stroke.color.trim();
        if color.is_empty() || colors.contains(&color) {
            continue;
        }
        colors.push(color);
        if colors.len() >= MAX_COLORS {
            break;
        }
    }

    if colors.is_empty() {
        "无明显颜色记录".to_string()
    } else {
        colors.join("、")
    }
}
/// Describes the bounding box covered by the stroke trace, e.g. `x 0.20-0.70, y 0.30-0.80`.
///
/// Points with a non-finite `x`, `y` or `t` are ignored; the remaining
/// coordinates are clamped to `[0, 1]` before the extremes are taken.
///
/// Returns `"无可用坐标记录"` when no usable point exists.
fn summarize_trace_bounds(stroke_trace: &[BabyLoveDrawingStrokePayload]) -> String {
    let mut usable_points = stroke_trace
        .iter()
        .flat_map(|stroke| stroke.points.iter())
        .filter(|point| point.x.is_finite() && point.y.is_finite() && point.t.is_finite())
        .map(|point| (point.x.clamp(0.0, 1.0), point.y.clamp(0.0, 1.0)))
        .peekable();

    if usable_points.peek().is_none() {
        return "无可用坐标记录".to_string();
    }

    // Start from the inverted unit box so the first point narrows it on both axes.
    let (min_x, max_x, min_y, max_y) = usable_points.fold(
        (1.0_f64, 0.0_f64, 1.0_f64, 0.0_f64),
        |(min_x, max_x, min_y, max_y), (x, y)| {
            (min_x.min(x), max_x.max(x), min_y.min(y), max_y.max(y))
        },
    );
    format!("x {min_x:.2}-{max_x:.2}, y {min_y:.2}-{max_y:.2}")
}
/// Returns the negative prompt sent alongside the main prompt.
///
/// The excluded concepts are listed as separate terms delimited by the
/// enumeration comma `、` (the same style the main prompt uses), so each term
/// keeps a clear boundary instead of fusing into one undelimited run of text.
fn build_baby_love_drawing_negative_prompt() -> &'static str {
    "文字、水印、Logo、按钮、UI面板、复杂背景、真实照片风、恐怖元素、成人化内容、攻击性内容、替换原图主体、完全无关的新画面"
}
/// Converts a downloaded image into a `data:image/png;base64,...` URL.
///
/// The bytes are first re-encoded as PNG by `normalize_generated_image_to_png`;
/// any error from that step is returned unchanged.
fn build_png_data_url(image: DownloadedOpenAiImage) -> Result<String, AppError> {
    let png_bytes = normalize_generated_image_to_png(image.bytes.as_slice())?;
    let encoded = BASE64_STANDARD.encode(png_bytes);
    Ok(["data:image/png;base64,", encoded.as_str()].concat())
}
/// Decodes image bytes in whatever format the `image` crate detects and re-encodes them as RGBA8 PNG.
///
/// # Errors
///
/// Returns a `502 Bad Gateway` `AppError` tagged with provider `vector-engine`
/// when the bytes cannot be decoded or the PNG encoding fails.
fn normalize_generated_image_to_png(source: &[u8]) -> Result<Vec<u8>, AppError> {
    // Both failure modes share the same status and provider tag.
    let gateway_error = |message: String| {
        AppError::from_status(StatusCode::BAD_GATEWAY).with_details(json!({
            "provider": "vector-engine",
            "message": message,
        }))
    };

    let decoded = match image::load_from_memory(source) {
        Ok(decoded) => decoded,
        Err(error) => {
            return Err(gateway_error(format!("解析宝贝爱画魔法图片失败:{error}")));
        }
    };
    let pixels = decoded.to_rgba8();

    let mut png_bytes = Vec::new();
    let write_result = PngEncoder::new(&mut png_bytes).write_image(
        pixels.as_raw(),
        pixels.width(),
        pixels.height(),
        ColorType::Rgba8.into(),
    );
    match write_result {
        Ok(()) => Ok(png_bytes),
        Err(error) => Err(gateway_error(format!(
            "转换宝贝爱画魔法图片为 PNG 失败:{error}"
        ))),
    }
}
/// Turns an `AppError` into an HTTP response that carries the current request context.
fn baby_love_drawing_error_response(request_context: &RequestContext, error: AppError) -> Response {
    let context = Some(request_context);
    error.into_response_with_context(context)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a single trace point.
    fn point(x: f64, y: f64, t: f64) -> BabyLoveDrawingPointPayload {
        BabyLoveDrawingPointPayload { x, y, t }
    }

    /// A valid request: a tiny PNG data URL plus one red brush stroke from (0.2, 0.3) to (0.7, 0.8).
    fn sample_request() -> CreateBabyLoveDrawingMagicRequest {
        let stroke = BabyLoveDrawingStrokePayload {
            stroke_id: String::from("stroke-1"),
            tool: String::from("brush"),
            color: String::from("#ef4444"),
            points: vec![point(0.2, 0.3, 1.0), point(0.7, 0.8, 2.0)],
        };
        CreateBabyLoveDrawingMagicRequest {
            original_image_src: String::from("data:image/png;base64,abcd"),
            stroke_trace: vec![stroke],
        }
    }

    #[test]
    fn magic_prompt_keeps_child_drawing_and_picture_book_style() {
        let request = sample_request();
        let prompt = build_baby_love_drawing_magic_prompt(&request.stroke_trace);
        let expected_fragments = [
            "宝贝爱画",
            "绘本风格",
            "保留小朋友原始画面",
            "#ef4444",
            "x 0.20-0.70",
        ];
        for fragment in expected_fragments {
            assert!(prompt.contains(fragment));
        }
    }

    #[test]
    fn magic_request_requires_image_data_url() {
        // The sample request is accepted as-is.
        assert!(validate_magic_request(&sample_request()).is_ok());

        // A plain https URL is not a data URL and must be rejected.
        let mut invalid = sample_request();
        invalid.original_image_src = String::from("https://example.test/image.png");
        assert!(validate_magic_request(&invalid).is_err());
    }

    #[test]
    fn normalizes_png_to_png_data_url() {
        // Encode a 2x2 opaque white RGBA image as the "downloaded" source.
        let pixels = [255u8; 4 * 2 * 2];
        let mut source = Vec::new();
        PngEncoder::new(&mut source)
            .write_image(&pixels[..], 2, 2, ColorType::Rgba8.into())
            .expect("test png should encode");

        let downloaded = DownloadedOpenAiImage {
            bytes: source,
            mime_type: String::from("image/png"),
            extension: String::from("png"),
        };
        let image_src = build_png_data_url(downloaded).expect("test png should normalize");
        assert!(image_src.starts_with("data:image/png;base64,"));
    }

    #[test]
    fn trace_summary_ignores_invalid_points() {
        let mut request = sample_request();
        // A NaN coordinate must not influence the reported bounds.
        let invalid_point = point(f64::NAN, 0.1, 3.0);
        request.stroke_trace[0].points.push(invalid_point);

        let summary = summarize_trace_bounds(&request.stroke_trace);
        assert_eq!(summary, "x 0.20-0.70, y 0.30-0.80");
    }
}