Integrate unfinished server-rs refactor worklists
This commit is contained in:
14
server-rs/crates/module-ai/src/application/result.rs
Normal file
14
server-rs/crates/module-ai/src/application/result.rs
Normal file
@@ -0,0 +1,14 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
#[cfg(feature = "spacetime-types")]
|
||||
use spacetimedb::SpacetimeType;
|
||||
|
||||
use crate::{AiTaskSnapshot, AiTextChunkSnapshot};
|
||||
|
||||
#[cfg_attr(feature = "spacetime-types", derive(SpacetimeType))]
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct AiTaskProcedureResult {
|
||||
pub ok: bool,
|
||||
pub task: Option<AiTaskSnapshot>,
|
||||
pub text_chunk: Option<AiTextChunkSnapshot>,
|
||||
pub error_message: Option<String>,
|
||||
}
|
||||
250
server-rs/crates/module-ai/src/application/service.rs
Normal file
250
server-rs/crates/module-ai/src/application/service.rs
Normal file
@@ -0,0 +1,250 @@
|
||||
use shared_kernel::normalize_required_string;
|
||||
|
||||
use crate::commands::validate_task_create_input;
|
||||
use crate::{
|
||||
AiResultReferenceKind, AiResultReferenceSnapshot, AiStageCompletionInput, AiTaskCreateInput,
|
||||
AiTaskFieldError, AiTaskServiceError, AiTaskSnapshot, AiTaskStageKind, AiTaskStageSnapshot,
|
||||
AiTaskStageStatus, AiTaskStatus, AiTextChunkSnapshot, INITIAL_AI_TASK_VERSION,
|
||||
generate_ai_result_ref_id, generate_ai_text_chunk_id, normalize_optional_text,
|
||||
normalize_string_list,
|
||||
};
|
||||
|
||||
use super::{InMemoryAiTaskStore, ensure_task_is_not_terminal};
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct AiTaskService {
|
||||
store: InMemoryAiTaskStore,
|
||||
}
|
||||
|
||||
impl AiTaskService {
|
||||
pub fn new(store: InMemoryAiTaskStore) -> Self {
|
||||
Self { store }
|
||||
}
|
||||
|
||||
pub fn create_task(
|
||||
&self,
|
||||
input: AiTaskCreateInput,
|
||||
) -> Result<AiTaskSnapshot, AiTaskServiceError> {
|
||||
validate_task_create_input(&input).map_err(AiTaskServiceError::Field)?;
|
||||
|
||||
let snapshot = AiTaskSnapshot {
|
||||
task_id: input.task_id.clone(),
|
||||
task_kind: input.task_kind,
|
||||
owner_user_id: normalize_required_string(input.owner_user_id).unwrap_or_default(),
|
||||
request_label: normalize_required_string(input.request_label).unwrap_or_default(),
|
||||
source_module: normalize_required_string(input.source_module).unwrap_or_default(),
|
||||
source_entity_id: normalize_optional_text(input.source_entity_id),
|
||||
request_payload_json: normalize_optional_text(input.request_payload_json),
|
||||
status: AiTaskStatus::Pending,
|
||||
failure_message: None,
|
||||
stages: input
|
||||
.stages
|
||||
.into_iter()
|
||||
.map(|stage| AiTaskStageSnapshot {
|
||||
stage_kind: stage.stage_kind,
|
||||
label: normalize_required_string(stage.label).unwrap_or_default(),
|
||||
detail: normalize_required_string(stage.detail).unwrap_or_default(),
|
||||
order: stage.order,
|
||||
status: AiTaskStageStatus::Pending,
|
||||
text_output: None,
|
||||
structured_payload_json: None,
|
||||
warning_messages: Vec::new(),
|
||||
started_at_micros: None,
|
||||
completed_at_micros: None,
|
||||
})
|
||||
.collect(),
|
||||
result_references: Vec::new(),
|
||||
latest_text_output: None,
|
||||
latest_structured_payload_json: None,
|
||||
version: INITIAL_AI_TASK_VERSION,
|
||||
created_at_micros: input.created_at_micros,
|
||||
started_at_micros: None,
|
||||
completed_at_micros: None,
|
||||
updated_at_micros: input.created_at_micros,
|
||||
};
|
||||
|
||||
self.store.insert_task(snapshot)
|
||||
}
|
||||
|
||||
pub fn start_task(
|
||||
&self,
|
||||
task_id: &str,
|
||||
started_at_micros: i64,
|
||||
) -> Result<AiTaskSnapshot, AiTaskServiceError> {
|
||||
self.store.update_task(task_id, |task| {
|
||||
ensure_task_is_not_terminal(task.status)?;
|
||||
task.status = AiTaskStatus::Running;
|
||||
task.started_at_micros.get_or_insert(started_at_micros);
|
||||
task.updated_at_micros = started_at_micros;
|
||||
task.version += 1;
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
pub fn start_stage(
|
||||
&self,
|
||||
task_id: &str,
|
||||
stage_kind: AiTaskStageKind,
|
||||
started_at_micros: i64,
|
||||
) -> Result<AiTaskSnapshot, AiTaskServiceError> {
|
||||
self.store.update_task(task_id, |task| {
|
||||
ensure_task_is_not_terminal(task.status)?;
|
||||
task.status = AiTaskStatus::Running;
|
||||
task.started_at_micros.get_or_insert(started_at_micros);
|
||||
let stage = task
|
||||
.stages
|
||||
.iter_mut()
|
||||
.find(|stage| stage.stage_kind == stage_kind)
|
||||
.ok_or(AiTaskServiceError::StageNotFound)?;
|
||||
stage.status = AiTaskStageStatus::Running;
|
||||
stage.started_at_micros.get_or_insert(started_at_micros);
|
||||
task.updated_at_micros = started_at_micros;
|
||||
task.version += 1;
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
pub fn append_text_chunk(
|
||||
&self,
|
||||
task_id: &str,
|
||||
stage_kind: AiTaskStageKind,
|
||||
sequence: u32,
|
||||
delta_text: String,
|
||||
created_at_micros: i64,
|
||||
) -> Result<(AiTaskSnapshot, AiTextChunkSnapshot), AiTaskServiceError> {
|
||||
if delta_text.trim().is_empty() {
|
||||
return Err(AiTaskServiceError::Field(
|
||||
AiTaskFieldError::MissingChunkText,
|
||||
));
|
||||
}
|
||||
if sequence == 0 {
|
||||
return Err(AiTaskServiceError::Field(AiTaskFieldError::InvalidSequence));
|
||||
}
|
||||
|
||||
let chunk = AiTextChunkSnapshot {
|
||||
chunk_id: generate_ai_text_chunk_id(created_at_micros, sequence),
|
||||
task_id: normalize_required_string(task_id).unwrap_or_default(),
|
||||
stage_kind,
|
||||
sequence,
|
||||
delta_text: normalize_required_string(delta_text).unwrap_or_default(),
|
||||
created_at_micros,
|
||||
};
|
||||
|
||||
let task = self.store.append_text_chunk(chunk.clone())?;
|
||||
Ok((task, chunk))
|
||||
}
|
||||
|
||||
pub fn complete_stage(
|
||||
&self,
|
||||
input: AiStageCompletionInput,
|
||||
) -> Result<AiTaskSnapshot, AiTaskServiceError> {
|
||||
self.store.update_task(&input.task_id, |task| {
|
||||
ensure_task_is_not_terminal(task.status)?;
|
||||
|
||||
let stage = task
|
||||
.stages
|
||||
.iter_mut()
|
||||
.find(|stage| stage.stage_kind == input.stage_kind)
|
||||
.ok_or(AiTaskServiceError::StageNotFound)?;
|
||||
stage.status = AiTaskStageStatus::Completed;
|
||||
stage.completed_at_micros = Some(input.completed_at_micros);
|
||||
stage.text_output = normalize_optional_text(input.text_output.clone());
|
||||
stage.structured_payload_json =
|
||||
normalize_optional_text(input.structured_payload_json.clone());
|
||||
stage.warning_messages = normalize_string_list(input.warning_messages.clone());
|
||||
|
||||
task.latest_text_output = stage.text_output.clone();
|
||||
task.latest_structured_payload_json = stage.structured_payload_json.clone();
|
||||
task.updated_at_micros = input.completed_at_micros;
|
||||
task.version += 1;
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
pub fn attach_result_reference(
|
||||
&self,
|
||||
task_id: &str,
|
||||
reference_kind: AiResultReferenceKind,
|
||||
reference_id: String,
|
||||
label: Option<String>,
|
||||
created_at_micros: i64,
|
||||
) -> Result<AiTaskSnapshot, AiTaskServiceError> {
|
||||
let Some(reference_id) = normalize_required_string(reference_id) else {
|
||||
return Err(AiTaskServiceError::Field(
|
||||
AiTaskFieldError::MissingReferenceId,
|
||||
));
|
||||
};
|
||||
|
||||
self.store.update_task(task_id, |task| {
|
||||
ensure_task_is_not_terminal(task.status)?;
|
||||
task.result_references.push(AiResultReferenceSnapshot {
|
||||
result_ref_id: generate_ai_result_ref_id(created_at_micros),
|
||||
task_id: task.task_id.clone(),
|
||||
reference_kind,
|
||||
reference_id: reference_id.clone(),
|
||||
label: normalize_optional_text(label.clone()),
|
||||
created_at_micros,
|
||||
});
|
||||
task.updated_at_micros = created_at_micros;
|
||||
task.version += 1;
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
pub fn complete_task(
|
||||
&self,
|
||||
task_id: &str,
|
||||
completed_at_micros: i64,
|
||||
) -> Result<AiTaskSnapshot, AiTaskServiceError> {
|
||||
self.store.update_task(task_id, |task| {
|
||||
ensure_task_is_not_terminal(task.status)?;
|
||||
task.status = AiTaskStatus::Completed;
|
||||
task.completed_at_micros = Some(completed_at_micros);
|
||||
task.updated_at_micros = completed_at_micros;
|
||||
task.version += 1;
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
pub fn fail_task(
|
||||
&self,
|
||||
task_id: &str,
|
||||
failure_message: String,
|
||||
completed_at_micros: i64,
|
||||
) -> Result<AiTaskSnapshot, AiTaskServiceError> {
|
||||
let Some(failure_message) = normalize_required_string(failure_message) else {
|
||||
return Err(AiTaskServiceError::Field(
|
||||
AiTaskFieldError::MissingFailureMessage,
|
||||
));
|
||||
};
|
||||
|
||||
self.store.update_task(task_id, |task| {
|
||||
ensure_task_is_not_terminal(task.status)?;
|
||||
task.status = AiTaskStatus::Failed;
|
||||
task.failure_message = Some(failure_message.clone());
|
||||
task.completed_at_micros = Some(completed_at_micros);
|
||||
task.updated_at_micros = completed_at_micros;
|
||||
task.version += 1;
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
pub fn cancel_task(
|
||||
&self,
|
||||
task_id: &str,
|
||||
completed_at_micros: i64,
|
||||
) -> Result<AiTaskSnapshot, AiTaskServiceError> {
|
||||
self.store.update_task(task_id, |task| {
|
||||
ensure_task_is_not_terminal(task.status)?;
|
||||
task.status = AiTaskStatus::Cancelled;
|
||||
task.completed_at_micros = Some(completed_at_micros);
|
||||
task.updated_at_micros = completed_at_micros;
|
||||
task.version += 1;
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
pub fn get_task(&self, task_id: &str) -> Result<AiTaskSnapshot, AiTaskServiceError> {
|
||||
self.store.get_task(task_id)
|
||||
}
|
||||
}
|
||||
138
server-rs/crates/module-ai/src/application/store.rs
Normal file
138
server-rs/crates/module-ai/src/application/store.rs
Normal file
@@ -0,0 +1,138 @@
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
sync::{Arc, Mutex},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
AiTaskServiceError, AiTaskSnapshot, AiTaskStageStatus, AiTaskStatus, AiTextChunkSnapshot,
|
||||
};
|
||||
|
||||
use super::ensure_task_is_not_terminal;
|
||||
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub struct InMemoryAiTaskStore {
|
||||
inner: Arc<Mutex<InMemoryAiTaskStoreState>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
struct InMemoryAiTaskStoreState {
|
||||
tasks: HashMap<String, AiTaskSnapshot>,
|
||||
text_chunks: HashMap<String, Vec<AiTextChunkSnapshot>>,
|
||||
}
|
||||
|
||||
impl InMemoryAiTaskStore {
|
||||
pub(super) fn insert_task(
|
||||
&self,
|
||||
task: AiTaskSnapshot,
|
||||
) -> Result<AiTaskSnapshot, AiTaskServiceError> {
|
||||
let mut state = self
|
||||
.inner
|
||||
.lock()
|
||||
.map_err(|_| AiTaskServiceError::Store("AI 任务仓储锁已中毒".to_string()))?;
|
||||
|
||||
if state.tasks.contains_key(&task.task_id) {
|
||||
return Err(AiTaskServiceError::TaskAlreadyExists);
|
||||
}
|
||||
|
||||
state.text_chunks.insert(task.task_id.clone(), Vec::new());
|
||||
state.tasks.insert(task.task_id.clone(), task.clone());
|
||||
Ok(task)
|
||||
}
|
||||
|
||||
pub(super) fn update_task<F>(
|
||||
&self,
|
||||
task_id: &str,
|
||||
mut apply: F,
|
||||
) -> Result<AiTaskSnapshot, AiTaskServiceError>
|
||||
where
|
||||
F: FnMut(&mut AiTaskSnapshot) -> Result<(), AiTaskServiceError>,
|
||||
{
|
||||
let mut state = self
|
||||
.inner
|
||||
.lock()
|
||||
.map_err(|_| AiTaskServiceError::Store("AI 任务仓储锁已中毒".to_string()))?;
|
||||
let task = state
|
||||
.tasks
|
||||
.get_mut(task_id.trim())
|
||||
.ok_or(AiTaskServiceError::TaskNotFound)?;
|
||||
apply(task)?;
|
||||
Ok(task.clone())
|
||||
}
|
||||
|
||||
pub(super) fn append_text_chunk(
|
||||
&self,
|
||||
chunk: AiTextChunkSnapshot,
|
||||
) -> Result<AiTaskSnapshot, AiTaskServiceError> {
|
||||
let mut state = self
|
||||
.inner
|
||||
.lock()
|
||||
.map_err(|_| AiTaskServiceError::Store("AI 任务仓储锁已中毒".to_string()))?;
|
||||
{
|
||||
let task = state
|
||||
.tasks
|
||||
.get_mut(&chunk.task_id)
|
||||
.ok_or(AiTaskServiceError::TaskNotFound)?;
|
||||
ensure_task_is_not_terminal(task.status)?;
|
||||
|
||||
let stage = task
|
||||
.stages
|
||||
.iter_mut()
|
||||
.find(|stage| stage.stage_kind == chunk.stage_kind)
|
||||
.ok_or(AiTaskServiceError::StageNotFound)?;
|
||||
if stage.status == AiTaskStageStatus::Pending {
|
||||
stage.status = AiTaskStageStatus::Running;
|
||||
stage.started_at_micros = Some(chunk.created_at_micros);
|
||||
}
|
||||
|
||||
task.status = AiTaskStatus::Running;
|
||||
task.started_at_micros
|
||||
.get_or_insert(chunk.created_at_micros);
|
||||
}
|
||||
|
||||
let chunks = state
|
||||
.text_chunks
|
||||
.get_mut(&chunk.task_id)
|
||||
.ok_or(AiTaskServiceError::TaskNotFound)?;
|
||||
chunks.push(chunk.clone());
|
||||
chunks.sort_by_key(|value| value.sequence);
|
||||
|
||||
let aggregated_text = chunks
|
||||
.iter()
|
||||
.filter(|value| value.stage_kind == chunk.stage_kind)
|
||||
.map(|value| value.delta_text.as_str())
|
||||
.collect::<Vec<_>>()
|
||||
.join("");
|
||||
let normalized_output = if aggregated_text.trim().is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(aggregated_text)
|
||||
};
|
||||
|
||||
let task = state
|
||||
.tasks
|
||||
.get_mut(&chunk.task_id)
|
||||
.ok_or(AiTaskServiceError::TaskNotFound)?;
|
||||
let stage = task
|
||||
.stages
|
||||
.iter_mut()
|
||||
.find(|stage| stage.stage_kind == chunk.stage_kind)
|
||||
.ok_or(AiTaskServiceError::StageNotFound)?;
|
||||
stage.text_output = normalized_output.clone();
|
||||
task.latest_text_output = normalized_output;
|
||||
task.updated_at_micros = chunk.created_at_micros;
|
||||
task.version += 1;
|
||||
Ok(task.clone())
|
||||
}
|
||||
|
||||
pub(super) fn get_task(&self, task_id: &str) -> Result<AiTaskSnapshot, AiTaskServiceError> {
|
||||
let state = self
|
||||
.inner
|
||||
.lock()
|
||||
.map_err(|_| AiTaskServiceError::Store("AI 任务仓储锁已中毒".to_string()))?;
|
||||
state
|
||||
.tasks
|
||||
.get(task_id.trim())
|
||||
.cloned()
|
||||
.ok_or(AiTaskServiceError::TaskNotFound)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user