This commit is contained in:
233
server-rs/crates/api-server/src/ai_generation_drafts.rs
Normal file
233
server-rs/crates/api-server/src/ai_generation_drafts.rs
Normal file
@@ -0,0 +1,233 @@
|
||||
use module_ai::{
|
||||
AiTaskCreateInput, AiTaskKind, AiTaskStageBlueprint, AiTaskStageKind, AiTaskStageStartInput,
|
||||
AiTextChunkAppendInput,
|
||||
};
|
||||
use serde_json::json;
|
||||
use spacetime_client::{SpacetimeClient, SpacetimeClientError};
|
||||
use std::sync::{Arc, Mutex};
|
||||
use tracing::warn;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub(crate) struct AiGenerationDraftContext {
|
||||
pub task_id: String,
|
||||
pub owner_user_id: String,
|
||||
pub request_label: String,
|
||||
pub source_module: String,
|
||||
pub source_entity_id: String,
|
||||
pub template_key: String,
|
||||
pub operation_id: String,
|
||||
}
|
||||
|
||||
impl AiGenerationDraftContext {
|
||||
pub fn new(
|
||||
template_key: &str,
|
||||
owner_user_id: &str,
|
||||
session_id: &str,
|
||||
operation_id: &str,
|
||||
request_label: &str,
|
||||
) -> Self {
|
||||
let normalized_template = normalize_identifier_segment(template_key);
|
||||
let normalized_session = normalize_identifier_segment(session_id);
|
||||
let normalized_operation = normalize_identifier_segment(operation_id);
|
||||
|
||||
Self {
|
||||
// 生成过程草稿使用稳定 task_id,保证同一模板会话操作重试时能继续定位已有内容。
|
||||
task_id: format!(
|
||||
"aitask_draft_{normalized_template}_{normalized_session}_{normalized_operation}"
|
||||
),
|
||||
owner_user_id: owner_user_id.trim().to_string(),
|
||||
request_label: request_label.trim().to_string(),
|
||||
source_module: normalized_template,
|
||||
source_entity_id: session_id.trim().to_string(),
|
||||
template_key: template_key.trim().to_string(),
|
||||
operation_id: operation_id.trim().to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub(crate) struct AiGenerationDraftSink {
|
||||
context: AiGenerationDraftContext,
|
||||
client: SpacetimeClient,
|
||||
next_sequence: Arc<Mutex<u32>>,
|
||||
persisted_text: Arc<Mutex<String>>,
|
||||
}
|
||||
|
||||
impl AiGenerationDraftSink {
|
||||
pub fn new(context: AiGenerationDraftContext, client: SpacetimeClient) -> Self {
|
||||
Self {
|
||||
context,
|
||||
client,
|
||||
next_sequence: Arc::new(Mutex::new(1)),
|
||||
persisted_text: Arc::new(Mutex::new(String::new())),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn persist_visible_text_async(&self, visible_text: &str) {
|
||||
let (sequence, delta_text) = {
|
||||
let mut persisted_text = self
|
||||
.persisted_text
|
||||
.lock()
|
||||
.unwrap_or_else(|poisoned| poisoned.into_inner());
|
||||
let delta_text = visible_text
|
||||
.strip_prefix(persisted_text.as_str())
|
||||
.unwrap_or(visible_text)
|
||||
.to_string();
|
||||
*persisted_text = visible_text.to_string();
|
||||
if delta_text.trim().is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let mut next_sequence = self
|
||||
.next_sequence
|
||||
.lock()
|
||||
.unwrap_or_else(|poisoned| poisoned.into_inner());
|
||||
let sequence = *next_sequence;
|
||||
*next_sequence = next_sequence.saturating_add(1);
|
||||
(sequence, delta_text)
|
||||
};
|
||||
let context = self.context.clone();
|
||||
let client = self.client.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(error) = client
|
||||
.append_ai_text_chunk(AiTextChunkAppendInput {
|
||||
task_id: context.task_id.clone(),
|
||||
stage_kind: AiTaskStageKind::RequestModel,
|
||||
sequence,
|
||||
delta_text,
|
||||
created_at_micros: current_utc_micros(),
|
||||
})
|
||||
.await
|
||||
{
|
||||
warn!(
|
||||
task_id = %context.task_id,
|
||||
sequence,
|
||||
error = %error,
|
||||
"AI 生成草稿后台增量落库失败,主生成流程继续执行"
|
||||
);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct AiGenerationDraftWriter {
|
||||
context: AiGenerationDraftContext,
|
||||
next_sequence: u32,
|
||||
persisted_text: String,
|
||||
}
|
||||
|
||||
impl AiGenerationDraftWriter {
|
||||
pub fn new(context: AiGenerationDraftContext) -> Self {
|
||||
Self {
|
||||
context,
|
||||
next_sequence: 1,
|
||||
persisted_text: String::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn ensure_started(
|
||||
&mut self,
|
||||
client: &SpacetimeClient,
|
||||
) -> Result<(), SpacetimeClientError> {
|
||||
let now_micros = current_utc_micros();
|
||||
match client
|
||||
.create_ai_task(AiTaskCreateInput {
|
||||
task_id: self.context.task_id.clone(),
|
||||
task_kind: AiTaskKind::CustomWorldGeneration,
|
||||
owner_user_id: self.context.owner_user_id.clone(),
|
||||
request_label: self.context.request_label.clone(),
|
||||
source_module: self.context.source_module.clone(),
|
||||
source_entity_id: Some(self.context.source_entity_id.clone()),
|
||||
request_payload_json: Some(
|
||||
json!({
|
||||
"templateKey": self.context.template_key,
|
||||
"operationId": self.context.operation_id,
|
||||
})
|
||||
.to_string(),
|
||||
),
|
||||
stages: vec![AiTaskStageBlueprint {
|
||||
stage_kind: AiTaskStageKind::RequestModel,
|
||||
label: "请求模型".to_string(),
|
||||
detail: "模板生成过程中持续写入模型已生成文本。".to_string(),
|
||||
order: 1,
|
||||
}],
|
||||
created_at_micros: now_micros,
|
||||
})
|
||||
.await
|
||||
{
|
||||
Ok(_) => {}
|
||||
Err(error) if is_duplicate_ai_task_error(&error) => {}
|
||||
Err(error) => return Err(error),
|
||||
}
|
||||
|
||||
client
|
||||
.start_ai_task_stage(AiTaskStageStartInput {
|
||||
task_id: self.context.task_id.clone(),
|
||||
stage_kind: AiTaskStageKind::RequestModel,
|
||||
started_at_micros: now_micros,
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn persist_visible_text(&mut self, client: &SpacetimeClient, visible_text: &str) {
|
||||
let delta_text = match visible_text.strip_prefix(self.persisted_text.as_str()) {
|
||||
Some(delta) => delta,
|
||||
None => visible_text,
|
||||
};
|
||||
if delta_text.trim().is_empty() {
|
||||
self.persisted_text = visible_text.to_string();
|
||||
return;
|
||||
}
|
||||
|
||||
let sequence = self.next_sequence;
|
||||
self.next_sequence = self.next_sequence.saturating_add(1);
|
||||
self.persisted_text = visible_text.to_string();
|
||||
|
||||
if let Err(error) = client
|
||||
.append_ai_text_chunk(AiTextChunkAppendInput {
|
||||
task_id: self.context.task_id.clone(),
|
||||
stage_kind: AiTaskStageKind::RequestModel,
|
||||
sequence,
|
||||
delta_text: delta_text.to_string(),
|
||||
created_at_micros: current_utc_micros(),
|
||||
})
|
||||
.await
|
||||
{
|
||||
warn!(
|
||||
task_id = %self.context.task_id,
|
||||
sequence,
|
||||
error = %error,
|
||||
"AI 生成草稿增量落库失败,主生成流程继续执行"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn normalize_identifier_segment(value: &str) -> String {
|
||||
let normalized = value
|
||||
.trim()
|
||||
.chars()
|
||||
.map(|character| {
|
||||
if character.is_ascii_alphanumeric() || character == '-' || character == '_' {
|
||||
character
|
||||
} else {
|
||||
'_'
|
||||
}
|
||||
})
|
||||
.collect::<String>();
|
||||
|
||||
if normalized.is_empty() {
|
||||
"unknown".to_string()
|
||||
} else {
|
||||
normalized
|
||||
}
|
||||
}
|
||||
|
||||
fn is_duplicate_ai_task_error(error: &SpacetimeClientError) -> bool {
|
||||
error.to_string().contains("ai_task.task_id 已存在")
|
||||
}
|
||||
|
||||
fn current_utc_micros() -> i64 {
|
||||
time::OffsetDateTime::now_utc().unix_timestamp_nanos() as i64 / 1_000
|
||||
}
|
||||
Reference in New Issue
Block a user