234 lines
7.7 KiB
Rust
234 lines
7.7 KiB
Rust
use module_ai::{
|
||
AiTaskCreateInput, AiTaskKind, AiTaskStageBlueprint, AiTaskStageKind, AiTaskStageStartInput,
|
||
AiTextChunkAppendInput,
|
||
};
|
||
use serde_json::json;
|
||
use spacetime_client::{SpacetimeClient, SpacetimeClientError};
|
||
use std::sync::{Arc, Mutex};
|
||
use tracing::warn;
|
||
|
||
/// Identity and metadata for a single AI generation draft task.
///
/// All fields are plain owned strings, so the context is cheap to clone and
/// can be moved into background tasks.
#[derive(Debug, Clone)]
pub(crate) struct AiGenerationDraftContext {
    /// Identifier of the draft task that text chunks are appended to.
    pub task_id: String,
    /// User that owns the draft task.
    pub owner_user_id: String,
    /// Label describing the request.
    pub request_label: String,
    /// Name of the module the draft originates from.
    pub source_module: String,
    /// Identifier of the entity the draft originates from.
    pub source_entity_id: String,
    /// Key of the template being generated.
    pub template_key: String,
    /// Identifier of the operation that triggered the generation.
    pub operation_id: String,
}
|
||
|
||
impl AiGenerationDraftContext {
|
||
pub fn new(
|
||
template_key: &str,
|
||
owner_user_id: &str,
|
||
session_id: &str,
|
||
operation_id: &str,
|
||
request_label: &str,
|
||
) -> Self {
|
||
let normalized_template = normalize_identifier_segment(template_key);
|
||
let normalized_session = normalize_identifier_segment(session_id);
|
||
let normalized_operation = normalize_identifier_segment(operation_id);
|
||
|
||
Self {
|
||
// 生成过程草稿使用稳定 task_id,保证同一模板会话操作重试时能继续定位已有内容。
|
||
task_id: format!(
|
||
"aitask_draft_{normalized_template}_{normalized_session}_{normalized_operation}"
|
||
),
|
||
owner_user_id: owner_user_id.trim().to_string(),
|
||
request_label: request_label.trim().to_string(),
|
||
source_module: normalized_template,
|
||
source_entity_id: session_id.trim().to_string(),
|
||
template_key: template_key.trim().to_string(),
|
||
operation_id: operation_id.trim().to_string(),
|
||
}
|
||
}
|
||
}
|
||
|
||
#[derive(Clone, Debug)]
|
||
pub(crate) struct AiGenerationDraftSink {
|
||
context: AiGenerationDraftContext,
|
||
client: SpacetimeClient,
|
||
next_sequence: Arc<Mutex<u32>>,
|
||
persisted_text: Arc<Mutex<String>>,
|
||
}
|
||
|
||
impl AiGenerationDraftSink {
|
||
pub fn new(context: AiGenerationDraftContext, client: SpacetimeClient) -> Self {
|
||
Self {
|
||
context,
|
||
client,
|
||
next_sequence: Arc::new(Mutex::new(1)),
|
||
persisted_text: Arc::new(Mutex::new(String::new())),
|
||
}
|
||
}
|
||
|
||
pub fn persist_visible_text_async(&self, visible_text: &str) {
|
||
let (sequence, delta_text) = {
|
||
let mut persisted_text = self
|
||
.persisted_text
|
||
.lock()
|
||
.unwrap_or_else(|poisoned| poisoned.into_inner());
|
||
let delta_text = visible_text
|
||
.strip_prefix(persisted_text.as_str())
|
||
.unwrap_or(visible_text)
|
||
.to_string();
|
||
*persisted_text = visible_text.to_string();
|
||
if delta_text.trim().is_empty() {
|
||
return;
|
||
}
|
||
|
||
let mut next_sequence = self
|
||
.next_sequence
|
||
.lock()
|
||
.unwrap_or_else(|poisoned| poisoned.into_inner());
|
||
let sequence = *next_sequence;
|
||
*next_sequence = next_sequence.saturating_add(1);
|
||
(sequence, delta_text)
|
||
};
|
||
let context = self.context.clone();
|
||
let client = self.client.clone();
|
||
tokio::spawn(async move {
|
||
if let Err(error) = client
|
||
.append_ai_text_chunk(AiTextChunkAppendInput {
|
||
task_id: context.task_id.clone(),
|
||
stage_kind: AiTaskStageKind::RequestModel,
|
||
sequence,
|
||
delta_text,
|
||
created_at_micros: current_utc_micros(),
|
||
})
|
||
.await
|
||
{
|
||
warn!(
|
||
task_id = %context.task_id,
|
||
sequence,
|
||
error = %error,
|
||
"AI 生成草稿后台增量落库失败,主生成流程继续执行"
|
||
);
|
||
}
|
||
});
|
||
}
|
||
}
|
||
|
||
#[derive(Debug)]
|
||
pub(crate) struct AiGenerationDraftWriter {
|
||
context: AiGenerationDraftContext,
|
||
next_sequence: u32,
|
||
persisted_text: String,
|
||
}
|
||
|
||
impl AiGenerationDraftWriter {
|
||
pub fn new(context: AiGenerationDraftContext) -> Self {
|
||
Self {
|
||
context,
|
||
next_sequence: 1,
|
||
persisted_text: String::new(),
|
||
}
|
||
}
|
||
|
||
pub async fn ensure_started(
|
||
&mut self,
|
||
client: &SpacetimeClient,
|
||
) -> Result<(), SpacetimeClientError> {
|
||
let now_micros = current_utc_micros();
|
||
match client
|
||
.create_ai_task(AiTaskCreateInput {
|
||
task_id: self.context.task_id.clone(),
|
||
task_kind: AiTaskKind::CustomWorldGeneration,
|
||
owner_user_id: self.context.owner_user_id.clone(),
|
||
request_label: self.context.request_label.clone(),
|
||
source_module: self.context.source_module.clone(),
|
||
source_entity_id: Some(self.context.source_entity_id.clone()),
|
||
request_payload_json: Some(
|
||
json!({
|
||
"templateKey": self.context.template_key,
|
||
"operationId": self.context.operation_id,
|
||
})
|
||
.to_string(),
|
||
),
|
||
stages: vec![AiTaskStageBlueprint {
|
||
stage_kind: AiTaskStageKind::RequestModel,
|
||
label: "请求模型".to_string(),
|
||
detail: "模板生成过程中持续写入模型已生成文本。".to_string(),
|
||
order: 1,
|
||
}],
|
||
created_at_micros: now_micros,
|
||
})
|
||
.await
|
||
{
|
||
Ok(_) => {}
|
||
Err(error) if is_duplicate_ai_task_error(&error) => {}
|
||
Err(error) => return Err(error),
|
||
}
|
||
|
||
client
|
||
.start_ai_task_stage(AiTaskStageStartInput {
|
||
task_id: self.context.task_id.clone(),
|
||
stage_kind: AiTaskStageKind::RequestModel,
|
||
started_at_micros: now_micros,
|
||
})
|
||
.await
|
||
}
|
||
|
||
pub async fn persist_visible_text(&mut self, client: &SpacetimeClient, visible_text: &str) {
|
||
let delta_text = match visible_text.strip_prefix(self.persisted_text.as_str()) {
|
||
Some(delta) => delta,
|
||
None => visible_text,
|
||
};
|
||
if delta_text.trim().is_empty() {
|
||
self.persisted_text = visible_text.to_string();
|
||
return;
|
||
}
|
||
|
||
let sequence = self.next_sequence;
|
||
self.next_sequence = self.next_sequence.saturating_add(1);
|
||
self.persisted_text = visible_text.to_string();
|
||
|
||
if let Err(error) = client
|
||
.append_ai_text_chunk(AiTextChunkAppendInput {
|
||
task_id: self.context.task_id.clone(),
|
||
stage_kind: AiTaskStageKind::RequestModel,
|
||
sequence,
|
||
delta_text: delta_text.to_string(),
|
||
created_at_micros: current_utc_micros(),
|
||
})
|
||
.await
|
||
{
|
||
warn!(
|
||
task_id = %self.context.task_id,
|
||
sequence,
|
||
error = %error,
|
||
"AI 生成草稿增量落库失败,主生成流程继续执行"
|
||
);
|
||
}
|
||
}
|
||
}
|
||
|
||
/// Turns arbitrary input into a segment that is safe to embed in an identifier.
///
/// Surrounding whitespace is trimmed. Every remaining character that is not an
/// ASCII letter, an ASCII digit, `-` or `_` is replaced with `_`. Input that is
/// empty after trimming yields `"unknown"`.
fn normalize_identifier_segment(value: &str) -> String {
    let trimmed = value.trim();
    if trimmed.is_empty() {
        return "unknown".to_owned();
    }

    // Each input character produces exactly one output character, and a
    // replacement is never wider than what it replaces.
    let mut segment = String::with_capacity(trimmed.len());
    for character in trimmed.chars() {
        let allowed = matches!(character, 'a'..='z' | 'A'..='Z' | '0'..='9' | '-' | '_');
        segment.push(if allowed { character } else { '_' });
    }
    segment
}
|
||
|
||
fn is_duplicate_ai_task_error(error: &SpacetimeClientError) -> bool {
|
||
error.to_string().contains("ai_task.task_id 已存在")
|
||
}
|
||
|
||
fn current_utc_micros() -> i64 {
|
||
time::OffsetDateTime::now_utc().unix_timestamp_nanos() as i64 / 1_000
|
||
}
|