Files
Genarrative/server-rs/crates/api-server/src/ai_generation_drafts.rs
kdletters cbc27bad4a
Some checks failed
CI / verify (push) Has been cancelled
init with react+axum+spacetimedb
2026-04-26 18:06:23 +08:00

234 lines
7.7 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
use module_ai::{
AiTaskCreateInput, AiTaskKind, AiTaskStageBlueprint, AiTaskStageKind, AiTaskStageStartInput,
AiTextChunkAppendInput,
};
use serde_json::json;
use spacetime_client::{SpacetimeClient, SpacetimeClientError};
use std::sync::{Arc, Mutex};
use tracing::warn;
/// Identifying data for a single AI generation draft task.
///
/// `AiGenerationDraftContext::new` populates every field from the raw
/// template/session/operation identifiers; the per-field notes below describe
/// what that constructor stores. The fields are public, so code that builds
/// the struct by hand is not bound by those conventions.
#[derive(Clone, Debug)]
pub(crate) struct AiGenerationDraftContext {
/// Draft task identifier. `new` builds it as
/// `aitask_draft_<template>_<session>_<operation>` from normalized segments.
pub task_id: String,
/// Owner user id, trimmed by `new`.
pub owner_user_id: String,
/// Request label, trimmed by `new`.
pub request_label: String,
/// Source module name; `new` stores the normalized template key here.
pub source_module: String,
/// Source entity id; `new` stores the trimmed session id here.
pub source_entity_id: String,
/// Template key, trimmed (but not normalized) by `new`.
pub template_key: String,
/// Operation id, trimmed (but not normalized) by `new`.
pub operation_id: String,
}
impl AiGenerationDraftContext {
    /// Builds the context for one draft-generation task.
    ///
    /// The template key, session id and operation id are each normalized into
    /// an identifier-safe segment and joined into `task_id`. The remaining
    /// fields keep the caller-supplied values with surrounding whitespace
    /// removed; `source_module` reuses the normalized template key and
    /// `source_entity_id` holds the trimmed session id.
    pub fn new(
        template_key: &str,
        owner_user_id: &str,
        session_id: &str,
        operation_id: &str,
        request_label: &str,
    ) -> Self {
        // Shared "trim and own" step for the fields stored verbatim.
        let trimmed = |raw: &str| raw.trim().to_owned();

        let normalized_template = normalize_identifier_segment(template_key);
        let normalized_session = normalize_identifier_segment(session_id);
        let normalized_operation = normalize_identifier_segment(operation_id);

        // The draft uses a deterministic task_id so that a retry of the same
        // template / session / operation can locate the content already written.
        let task_id = format!(
            "aitask_draft_{normalized_template}_{normalized_session}_{normalized_operation}"
        );

        Self {
            task_id,
            owner_user_id: trimmed(owner_user_id),
            request_label: trimmed(request_label),
            source_module: normalized_template,
            source_entity_id: trimmed(session_id),
            template_key: trimmed(template_key),
            operation_id: trimmed(operation_id),
        }
    }
}
/// Fire-and-forget sink that persists generated text as incremental chunks.
///
/// The sequence counter and the text snapshot live behind `Arc<Mutex<_>>`, so
/// clones of a sink (the type derives `Clone`) share both pieces of state.
#[derive(Clone, Debug)]
pub(crate) struct AiGenerationDraftSink {
/// Identifies the draft task the chunks belong to.
context: AiGenerationDraftContext,
/// Client used to append text chunks.
client: SpacetimeClient,
/// Sequence number assigned to the next chunk; `new` starts it at 1.
next_sequence: Arc<Mutex<u32>>,
/// Visible text recorded by the most recent persist call; used to compute
/// the delta for the next one.
persisted_text: Arc<Mutex<String>>,
}
impl AiGenerationDraftSink {
    /// Creates a sink for `context` that writes through `client`.
    ///
    /// Chunk sequence numbers start at 1 and the text snapshot starts empty.
    pub fn new(context: AiGenerationDraftContext, client: SpacetimeClient) -> Self {
        Self {
            context,
            client,
            next_sequence: Arc::new(Mutex::new(1)),
            persisted_text: Arc::new(Mutex::new(String::new())),
        }
    }

    /// Persists the not-yet-stored part of `visible_text` on a background task.
    ///
    /// `visible_text` is the full text generated so far. The delta is the
    /// suffix that follows the previously recorded snapshot, or the whole
    /// text when it no longer starts with that snapshot. Blank deltas are
    /// never sent.
    ///
    /// When the new text extends the snapshot by whitespace only, the
    /// snapshot is deliberately left untouched: the whitespace is then
    /// included in the next non-blank delta instead of being lost from the
    /// persisted chunks.
    ///
    /// The write is spawned with `tokio::spawn` and is best-effort: a failure
    /// is logged as a warning and never reaches the caller.
    pub fn persist_visible_text_async(&self, visible_text: &str) {
        let (sequence, delta_text) = {
            // A poisoned lock only means another thread panicked mid-update;
            // the draft is best-effort, so keep going with the inner value.
            let mut persisted_text = self
                .persisted_text
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());

            let extension = visible_text.strip_prefix(persisted_text.as_str());
            if matches!(extension, Some(suffix) if suffix.trim().is_empty()) {
                // Nothing new, or only whitespace so far: defer it so it is
                // carried by the next chunk rather than dropped.
                return;
            }

            let delta_text = extension.unwrap_or(visible_text).to_string();
            *persisted_text = visible_text.to_string();
            if delta_text.trim().is_empty() {
                // Only reachable when the text diverged from the snapshot and
                // is blank: follow the new text but send nothing.
                return;
            }

            // Allocate the sequence number while the snapshot lock is still
            // held so numbering follows the order in which deltas were cut.
            let mut next_sequence = self
                .next_sequence
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            let sequence = *next_sequence;
            *next_sequence = next_sequence.saturating_add(1);
            (sequence, delta_text)
        };

        let context = self.context.clone();
        let client = self.client.clone();
        tokio::spawn(async move {
            if let Err(error) = client
                .append_ai_text_chunk(AiTextChunkAppendInput {
                    task_id: context.task_id.clone(),
                    stage_kind: AiTaskStageKind::RequestModel,
                    sequence,
                    delta_text,
                    created_at_micros: current_utc_micros(),
                })
                .await
            {
                warn!(
                    task_id = %context.task_id,
                    sequence,
                    error = %error,
                    "AI 生成草稿后台增量落库失败,主生成流程继续执行"
                );
            }
        });
    }
}
/// Awaiting counterpart of the draft sink: persists generated text as
/// incremental chunks, with the caller awaiting each write.
#[derive(Debug)]
pub(crate) struct AiGenerationDraftWriter {
/// Identifies the draft task the chunks belong to.
context: AiGenerationDraftContext,
/// Sequence number assigned to the next chunk; `new` starts it at 1.
next_sequence: u32,
/// Visible text recorded by the most recent persist call; used to compute
/// the delta for the next one.
persisted_text: String,
}
impl AiGenerationDraftWriter {
    /// Creates a writer for `context`.
    ///
    /// Chunk sequence numbers start at 1 and the text snapshot starts empty.
    pub fn new(context: AiGenerationDraftContext) -> Self {
        Self {
            context,
            next_sequence: 1,
            persisted_text: String::new(),
        }
    }

    /// Creates the draft task (if needed) and starts its `RequestModel` stage.
    ///
    /// A create failure recognised by `is_duplicate_ai_task_error` is treated
    /// as success so that a retry with the same task id can proceed; the
    /// stage start is still issued in that case.
    ///
    /// # Errors
    ///
    /// Returns any other error from creating the task, or the error from
    /// starting the stage.
    pub async fn ensure_started(
        &mut self,
        client: &SpacetimeClient,
    ) -> Result<(), SpacetimeClientError> {
        // One timestamp is shared by the create and the stage-start calls.
        let now_micros = current_utc_micros();
        match client
            .create_ai_task(AiTaskCreateInput {
                task_id: self.context.task_id.clone(),
                task_kind: AiTaskKind::CustomWorldGeneration,
                owner_user_id: self.context.owner_user_id.clone(),
                request_label: self.context.request_label.clone(),
                source_module: self.context.source_module.clone(),
                source_entity_id: Some(self.context.source_entity_id.clone()),
                request_payload_json: Some(
                    json!({
                        "templateKey": self.context.template_key,
                        "operationId": self.context.operation_id,
                    })
                    .to_string(),
                ),
                stages: vec![AiTaskStageBlueprint {
                    stage_kind: AiTaskStageKind::RequestModel,
                    label: "请求模型".to_string(),
                    detail: "模板生成过程中持续写入模型已生成文本。".to_string(),
                    order: 1,
                }],
                created_at_micros: now_micros,
            })
            .await
        {
            Ok(_) => {}
            // The task already exists (e.g. a retry): reuse it.
            Err(error) if is_duplicate_ai_task_error(&error) => {}
            Err(error) => return Err(error),
        }
        client
            .start_ai_task_stage(AiTaskStageStartInput {
                task_id: self.context.task_id.clone(),
                stage_kind: AiTaskStageKind::RequestModel,
                started_at_micros: now_micros,
            })
            .await
    }

    /// Persists the not-yet-stored part of `visible_text` and awaits the write.
    ///
    /// `visible_text` is the full text generated so far. The delta is the
    /// suffix that follows the previously recorded snapshot, or the whole
    /// text when it no longer starts with that snapshot. Blank deltas are
    /// never sent.
    ///
    /// When the new text extends the snapshot by whitespace only, the
    /// snapshot is deliberately left untouched: the whitespace is then
    /// included in the next non-blank delta instead of being lost from the
    /// persisted chunks.
    ///
    /// The write is best-effort: a failure is logged as a warning and is not
    /// returned. The sequence number and snapshot have already advanced by
    /// then, so a failed chunk is not retried.
    pub async fn persist_visible_text(&mut self, client: &SpacetimeClient, visible_text: &str) {
        let extension = visible_text.strip_prefix(self.persisted_text.as_str());
        if matches!(extension, Some(suffix) if suffix.trim().is_empty()) {
            // Nothing new, or only whitespace so far: defer it so it is
            // carried by the next chunk rather than dropped.
            return;
        }

        let delta_text = extension.unwrap_or(visible_text);
        self.persisted_text = visible_text.to_string();
        if delta_text.trim().is_empty() {
            // Only reachable when the text diverged from the snapshot and is
            // blank: follow the new text but send nothing.
            return;
        }

        let sequence = self.next_sequence;
        self.next_sequence = self.next_sequence.saturating_add(1);

        if let Err(error) = client
            .append_ai_text_chunk(AiTextChunkAppendInput {
                task_id: self.context.task_id.clone(),
                stage_kind: AiTaskStageKind::RequestModel,
                sequence,
                delta_text: delta_text.to_string(),
                created_at_micros: current_utc_micros(),
            })
            .await
        {
            warn!(
                task_id = %self.context.task_id,
                sequence,
                error = %error,
                "AI 生成草稿增量落库失败,主生成流程继续执行"
            );
        }
    }
}
/// Turns `value` into a segment that is safe to embed in an identifier.
///
/// Surrounding whitespace is trimmed, then every character other than an
/// ASCII letter, ASCII digit, `-` or `_` is replaced by `_`. Input that is
/// empty after trimming yields `"unknown"`.
fn normalize_identifier_segment(value: &str) -> String {
    let trimmed = value.trim();
    if trimmed.is_empty() {
        return "unknown".to_string();
    }

    let mut segment = String::new();
    for ch in trimmed.chars() {
        let keep = matches!(ch, '-' | '_') || ch.is_ascii_alphanumeric();
        segment.push(if keep { ch } else { '_' });
    }
    segment
}
/// Reports whether `error` says the AI task id already exists.
///
/// Detection is textual: the error is rendered with `to_string()` and searched
/// for the marker below, so it depends on that wording staying unchanged.
fn is_duplicate_ai_task_error(error: &SpacetimeClientError) -> bool {
    let rendered = error.to_string();
    rendered.contains("ai_task.task_id 已存在")
}
/// Returns the current UTC time as microseconds since the Unix epoch.
fn current_utc_micros() -> i64 {
    // Divide while the value is still in its wide nanosecond type: `as` binds
    // tighter than `/`, so casting first would truncate the nanosecond count
    // to i64 before the division.
    let micros = time::OffsetDateTime::now_utc().unix_timestamp_nanos() / 1_000;
    // Saturate rather than wrap if the value could ever fall outside i64.
    micros.clamp(i128::from(i64::MIN), i128::from(i64::MAX)) as i64
}