use module_ai::{ AiTaskCreateInput, AiTaskKind, AiTaskStageBlueprint, AiTaskStageKind, AiTaskStageStartInput, AiTextChunkAppendInput, }; use serde_json::json; use spacetime_client::{SpacetimeClient, SpacetimeClientError}; use std::sync::{Arc, Mutex}; use tracing::warn; #[derive(Clone, Debug)] pub(crate) struct AiGenerationDraftContext { pub task_id: String, pub owner_user_id: String, pub request_label: String, pub source_module: String, pub source_entity_id: String, pub template_key: String, pub operation_id: String, } impl AiGenerationDraftContext { pub fn new( template_key: &str, owner_user_id: &str, session_id: &str, operation_id: &str, request_label: &str, ) -> Self { let normalized_template = normalize_identifier_segment(template_key); let normalized_session = normalize_identifier_segment(session_id); let normalized_operation = normalize_identifier_segment(operation_id); Self { // 生成过程草稿使用稳定 task_id,保证同一模板会话操作重试时能继续定位已有内容。 task_id: format!( "aitask_draft_{normalized_template}_{normalized_session}_{normalized_operation}" ), owner_user_id: owner_user_id.trim().to_string(), request_label: request_label.trim().to_string(), source_module: normalized_template, source_entity_id: session_id.trim().to_string(), template_key: template_key.trim().to_string(), operation_id: operation_id.trim().to_string(), } } } #[derive(Clone, Debug)] pub(crate) struct AiGenerationDraftSink { context: AiGenerationDraftContext, client: SpacetimeClient, next_sequence: Arc>, persisted_text: Arc>, } impl AiGenerationDraftSink { pub fn new(context: AiGenerationDraftContext, client: SpacetimeClient) -> Self { Self { context, client, next_sequence: Arc::new(Mutex::new(1)), persisted_text: Arc::new(Mutex::new(String::new())), } } pub fn persist_visible_text_async(&self, visible_text: &str) { let (sequence, delta_text) = { let mut persisted_text = self .persisted_text .lock() .unwrap_or_else(|poisoned| poisoned.into_inner()); let delta_text = visible_text .strip_prefix(persisted_text.as_str()) .unwrap_or(visible_text) .to_string(); *persisted_text = visible_text.to_string(); if delta_text.trim().is_empty() { return; } let mut next_sequence = self .next_sequence .lock() .unwrap_or_else(|poisoned| poisoned.into_inner()); let sequence = *next_sequence; *next_sequence = next_sequence.saturating_add(1); (sequence, delta_text) }; let context = self.context.clone(); let client = self.client.clone(); tokio::spawn(async move { if let Err(error) = client .append_ai_text_chunk(AiTextChunkAppendInput { task_id: context.task_id.clone(), stage_kind: AiTaskStageKind::RequestModel, sequence, delta_text, created_at_micros: current_utc_micros(), }) .await { warn!( task_id = %context.task_id, sequence, error = %error, "AI 生成草稿后台增量落库失败,主生成流程继续执行" ); } }); } } #[derive(Debug)] pub(crate) struct AiGenerationDraftWriter { context: AiGenerationDraftContext, next_sequence: u32, persisted_text: String, } impl AiGenerationDraftWriter { pub fn new(context: AiGenerationDraftContext) -> Self { Self { context, next_sequence: 1, persisted_text: String::new(), } } pub async fn ensure_started( &mut self, client: &SpacetimeClient, ) -> Result<(), SpacetimeClientError> { let now_micros = current_utc_micros(); match client .create_ai_task(AiTaskCreateInput { task_id: self.context.task_id.clone(), task_kind: AiTaskKind::CustomWorldGeneration, owner_user_id: self.context.owner_user_id.clone(), request_label: self.context.request_label.clone(), source_module: self.context.source_module.clone(), source_entity_id: Some(self.context.source_entity_id.clone()), request_payload_json: Some( json!({ "templateKey": self.context.template_key, "operationId": self.context.operation_id, }) .to_string(), ), stages: vec![AiTaskStageBlueprint { stage_kind: AiTaskStageKind::RequestModel, label: "请求模型".to_string(), detail: "模板生成过程中持续写入模型已生成文本。".to_string(), order: 1, }], created_at_micros: now_micros, }) .await { Ok(_) => {} Err(error) if is_duplicate_ai_task_error(&error) => {} Err(error) => return Err(error), } client .start_ai_task_stage(AiTaskStageStartInput { task_id: self.context.task_id.clone(), stage_kind: AiTaskStageKind::RequestModel, started_at_micros: now_micros, }) .await } pub async fn persist_visible_text(&mut self, client: &SpacetimeClient, visible_text: &str) { let delta_text = match visible_text.strip_prefix(self.persisted_text.as_str()) { Some(delta) => delta, None => visible_text, }; if delta_text.trim().is_empty() { self.persisted_text = visible_text.to_string(); return; } let sequence = self.next_sequence; self.next_sequence = self.next_sequence.saturating_add(1); self.persisted_text = visible_text.to_string(); if let Err(error) = client .append_ai_text_chunk(AiTextChunkAppendInput { task_id: self.context.task_id.clone(), stage_kind: AiTaskStageKind::RequestModel, sequence, delta_text: delta_text.to_string(), created_at_micros: current_utc_micros(), }) .await { warn!( task_id = %self.context.task_id, sequence, error = %error, "AI 生成草稿增量落库失败,主生成流程继续执行" ); } } } fn normalize_identifier_segment(value: &str) -> String { let normalized = value .trim() .chars() .map(|character| { if character.is_ascii_alphanumeric() || character == '-' || character == '_' { character } else { '_' } }) .collect::(); if normalized.is_empty() { "unknown".to_string() } else { normalized } } fn is_duplicate_ai_task_error(error: &SpacetimeClientError) -> bool { error.to_string().contains("ai_task.task_id 已存在") } fn current_utc_micros() -> i64 { time::OffsetDateTime::now_utc().unix_timestamp_nanos() as i64 / 1_000 }