use axum::{ Json, extract::{Extension, State, rejection::JsonRejection}, http::StatusCode, response::Response, }; use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64_STANDARD}; use image::{ColorType, ImageEncoder, codecs::png::PngEncoder}; use serde::{Deserialize, Serialize}; use serde_json::{Value, json}; use crate::{ api_response::json_success_body, http_error::AppError, openai_image_generation::{ DownloadedOpenAiImage, build_openai_image_http_client, create_openai_image_generation, require_openai_image_settings, }, request_context::RequestContext, state::AppState, }; const BABY_LOVE_DRAWING_PROVIDER: &str = "vector-engine-gpt-image-2"; const BABY_LOVE_DRAWING_IMAGE_SIZE: &str = "1024x1024"; const BABY_LOVE_DRAWING_MAX_STROKES: usize = 600; #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase")] pub(crate) struct CreateBabyLoveDrawingMagicRequest { original_image_src: String, #[serde(default)] stroke_trace: Vec, } #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase")] struct BabyLoveDrawingStrokePayload { stroke_id: String, tool: String, color: String, #[serde(default)] points: Vec, } #[derive(Debug, Deserialize)] struct BabyLoveDrawingPointPayload { x: f64, y: f64, t: f64, } #[derive(Debug, Serialize)] #[serde(rename_all = "camelCase")] struct CreateBabyLoveDrawingMagicResponse { magic_image_src: String, generation_provider: String, prompt: String, } pub async fn create_baby_love_drawing_magic( State(state): State, Extension(request_context): Extension, payload: Result, JsonRejection>, ) -> Result, Response> { let Json(payload) = payload.map_err(|error| { baby_love_drawing_error_response( &request_context, AppError::from_status(StatusCode::BAD_REQUEST).with_details(json!({ "provider": "edutainment-baby-drawing", "message": error.body_text(), })), ) })?; validate_magic_request(&payload) .map_err(|error| baby_love_drawing_error_response(&request_context, error))?; let settings = require_openai_image_settings(&state) .map_err(|error| baby_love_drawing_error_response(&request_context, error))?; let http_client = build_openai_image_http_client(&settings) .map_err(|error| baby_love_drawing_error_response(&request_context, error))?; let prompt = build_baby_love_drawing_magic_prompt(payload.stroke_trace.as_slice()); let reference_images = vec![payload.original_image_src.trim().to_string()]; let generated = create_openai_image_generation( &http_client, &settings, prompt.as_str(), Some(build_baby_love_drawing_negative_prompt()), BABY_LOVE_DRAWING_IMAGE_SIZE, 1, reference_images.as_slice(), "宝贝爱画绘画魔法图片生成失败", ) .await .map_err(|error| baby_love_drawing_error_response(&request_context, error))?; let generated_image = generated.images.into_iter().next().ok_or_else(|| { baby_love_drawing_error_response( &request_context, AppError::from_status(StatusCode::BAD_GATEWAY).with_details(json!({ "provider": "vector-engine", "message": "宝贝爱画绘画魔法没有返回图片。", })), ) })?; let magic_image_src = build_png_data_url(generated_image) .map_err(|error| baby_love_drawing_error_response(&request_context, error))?; Ok(json_success_body( Some(&request_context), CreateBabyLoveDrawingMagicResponse { magic_image_src, generation_provider: BABY_LOVE_DRAWING_PROVIDER.to_string(), prompt, }, )) } fn validate_magic_request(payload: &CreateBabyLoveDrawingMagicRequest) -> Result<(), AppError> { let original_image_src = payload.original_image_src.trim(); if !original_image_src.starts_with("data:image/") || !original_image_src.contains(";base64,") { return Err( AppError::from_status(StatusCode::BAD_REQUEST).with_details(json!({ "provider": "edutainment-baby-drawing", "message": "绘画原图必须是图片 Data URL。", })), ); } if payload.stroke_trace.len() > BABY_LOVE_DRAWING_MAX_STROKES { return Err( AppError::from_status(StatusCode::BAD_REQUEST).with_details(json!({ "provider": "edutainment-baby-drawing", "message": "绘画笔触数量过多,请重新完成绘画后再使用魔法。", })), ); } Ok(()) } fn build_baby_love_drawing_magic_prompt(stroke_trace: &[BabyLoveDrawingStrokePayload]) -> String { let stroke_count = stroke_trace.len(); let brush_count = stroke_trace .iter() .filter(|stroke| stroke.tool.trim() == "brush") .count(); let eraser_count = stroke_trace .iter() .filter(|stroke| stroke.tool.trim() == "eraser") .count(); let color_summary = summarize_stroke_colors(stroke_trace); let trace_bounds = summarize_trace_bounds(stroke_trace); format!( "根据参考图中的儿童绘画内容,为寓教于乐独立关卡“宝贝爱画”生成一张绘本风格图片。\n\ 必须保留小朋友原始画面的主体构图、线条方向、颜色关系和童趣笔触,不要改成与原图无关的新内容。\n\ 输出风格:明亮、温暖、柔和、卡通绘本风格,适合 4-8 岁儿童,画面干净,边缘柔和,有轻微纸面质感。\n\ 笔触信息:总笔触 {stroke_count} 条,画笔 {brush_count} 条,橡皮 {eraser_count} 条,主要颜色 {color_summary},绘制范围 {trace_bounds}。\n\ 不要生成文字、水印、Logo、按钮、UI 面板、真实照片风、恐怖或成人化内容。" ) } fn summarize_stroke_colors(stroke_trace: &[BabyLoveDrawingStrokePayload]) -> String { let mut colors = Vec::new(); for stroke in stroke_trace { if stroke.stroke_id.trim().is_empty() { continue; } let color = stroke.color.trim(); if color.is_empty() || colors.iter().any(|value| value == color) { continue; } colors.push(color.to_string()); if colors.len() >= 5 { break; } } if colors.is_empty() { "无明显颜色记录".to_string() } else { colors.join("、") } } fn summarize_trace_bounds(stroke_trace: &[BabyLoveDrawingStrokePayload]) -> String { let mut min_x = 1.0_f64; let mut min_y = 1.0_f64; let mut max_x = 0.0_f64; let mut max_y = 0.0_f64; let mut has_point = false; for point in stroke_trace.iter().flat_map(|stroke| stroke.points.iter()) { if !(point.x.is_finite() && point.y.is_finite() && point.t.is_finite()) { continue; } has_point = true; min_x = min_x.min(point.x.clamp(0.0, 1.0)); min_y = min_y.min(point.y.clamp(0.0, 1.0)); max_x = max_x.max(point.x.clamp(0.0, 1.0)); max_y = max_y.max(point.y.clamp(0.0, 1.0)); } if !has_point { return "无可用坐标记录".to_string(); } format!("x {:.2}-{:.2}, y {:.2}-{:.2}", min_x, max_x, min_y, max_y) } fn build_baby_love_drawing_negative_prompt() -> &'static str { "文字,水印,Logo,按钮,UI,面板,复杂背景,真实照片风,恐怖元素,成人化内容,攻击性内容,替换原图主体,完全无关的新画面" } fn build_png_data_url(image: DownloadedOpenAiImage) -> Result { let png_bytes = normalize_generated_image_to_png(image.bytes.as_slice())?; Ok(format!( "data:image/png;base64,{}", BASE64_STANDARD.encode(png_bytes) )) } fn normalize_generated_image_to_png(source: &[u8]) -> Result, AppError> { let rgba_image = image::load_from_memory(source) .map_err(|error| { AppError::from_status(StatusCode::BAD_GATEWAY).with_details(json!({ "provider": "vector-engine", "message": format!("解析宝贝爱画魔法图片失败:{error}"), })) })? .to_rgba8(); let (width, height) = rgba_image.dimensions(); let mut encoded = Vec::new(); let encoder = PngEncoder::new(&mut encoded); encoder .write_image(rgba_image.as_raw(), width, height, ColorType::Rgba8.into()) .map_err(|error| { AppError::from_status(StatusCode::BAD_GATEWAY).with_details(json!({ "provider": "vector-engine", "message": format!("转换宝贝爱画魔法图片为 PNG 失败:{error}"), })) })?; Ok(encoded) } fn baby_love_drawing_error_response(request_context: &RequestContext, error: AppError) -> Response { error.into_response_with_context(Some(request_context)) } #[cfg(test)] mod tests { use super::*; fn sample_request() -> CreateBabyLoveDrawingMagicRequest { CreateBabyLoveDrawingMagicRequest { original_image_src: "data:image/png;base64,abcd".to_string(), stroke_trace: vec![BabyLoveDrawingStrokePayload { stroke_id: "stroke-1".to_string(), tool: "brush".to_string(), color: "#ef4444".to_string(), points: vec![ BabyLoveDrawingPointPayload { x: 0.2, y: 0.3, t: 1.0, }, BabyLoveDrawingPointPayload { x: 0.7, y: 0.8, t: 2.0, }, ], }], } } #[test] fn magic_prompt_keeps_child_drawing_and_picture_book_style() { let request = sample_request(); let prompt = build_baby_love_drawing_magic_prompt(request.stroke_trace.as_slice()); assert!(prompt.contains("宝贝爱画")); assert!(prompt.contains("绘本风格")); assert!(prompt.contains("保留小朋友原始画面")); assert!(prompt.contains("#ef4444")); assert!(prompt.contains("x 0.20-0.70")); } #[test] fn magic_request_requires_image_data_url() { let request = sample_request(); assert!(validate_magic_request(&request).is_ok()); let invalid = CreateBabyLoveDrawingMagicRequest { original_image_src: "https://example.test/image.png".to_string(), ..sample_request() }; assert!(validate_magic_request(&invalid).is_err()); } #[test] fn normalizes_png_to_png_data_url() { let mut source = Vec::new(); let pixels = vec![255u8; 4 * 2 * 2]; let encoder = PngEncoder::new(&mut source); encoder .write_image(pixels.as_slice(), 2, 2, ColorType::Rgba8.into()) .expect("test png should encode"); let image_src = build_png_data_url(DownloadedOpenAiImage { bytes: source, mime_type: "image/png".to_string(), extension: "png".to_string(), }) .expect("test png should normalize"); assert!(image_src.starts_with("data:image/png;base64,")); } #[test] fn trace_summary_ignores_invalid_points() { let mut request = sample_request(); request.stroke_trace[0] .points .push(BabyLoveDrawingPointPayload { x: f64::NAN, y: 0.1, t: 3.0, }); assert_eq!( summarize_trace_bounds(request.stroke_trace.as_slice()), "x 0.20-0.70, y 0.30-0.80", ); } }