Files
Genarrative/server-rs/crates/module-ai/src/application/store.rs

139 lines
4.3 KiB
Rust

use std::{
collections::HashMap,
sync::{Arc, Mutex},
};
use crate::{
AiTaskServiceError, AiTaskSnapshot, AiTaskStageStatus, AiTaskStatus, AiTextChunkSnapshot,
};
use super::ensure_task_is_not_terminal;
/// In-memory store for AI task snapshots and their streamed text chunks.
///
/// Cloning the store is cheap: every clone shares the same mutex-guarded
/// state through an [`Arc`].
#[derive(Debug, Default, Clone)]
pub struct InMemoryAiTaskStore {
    /// Shared state; every read and write goes through this mutex.
    inner: Arc<Mutex<InMemoryAiTaskStoreState>>,
}
/// The data guarded by the store's mutex.
#[derive(Default, Debug)]
struct InMemoryAiTaskStoreState {
    /// Task snapshots, keyed by task id.
    tasks: HashMap<String, AiTaskSnapshot>,
    /// Text chunks recorded for each task, keyed by task id and kept sorted
    /// by chunk `sequence`.
    text_chunks: HashMap<String, Vec<AiTextChunkSnapshot>>,
}
impl InMemoryAiTaskStore {
pub(super) fn insert_task(
&self,
task: AiTaskSnapshot,
) -> Result<AiTaskSnapshot, AiTaskServiceError> {
let mut state = self
.inner
.lock()
.map_err(|_| AiTaskServiceError::Store("AI 任务仓储锁已中毒".to_string()))?;
if state.tasks.contains_key(&task.task_id) {
return Err(AiTaskServiceError::TaskAlreadyExists);
}
state.text_chunks.insert(task.task_id.clone(), Vec::new());
state.tasks.insert(task.task_id.clone(), task.clone());
Ok(task)
}
pub(super) fn update_task<F>(
&self,
task_id: &str,
mut apply: F,
) -> Result<AiTaskSnapshot, AiTaskServiceError>
where
F: FnMut(&mut AiTaskSnapshot) -> Result<(), AiTaskServiceError>,
{
let mut state = self
.inner
.lock()
.map_err(|_| AiTaskServiceError::Store("AI 任务仓储锁已中毒".to_string()))?;
let task = state
.tasks
.get_mut(task_id.trim())
.ok_or(AiTaskServiceError::TaskNotFound)?;
apply(task)?;
Ok(task.clone())
}
pub(super) fn append_text_chunk(
&self,
chunk: AiTextChunkSnapshot,
) -> Result<AiTaskSnapshot, AiTaskServiceError> {
let mut state = self
.inner
.lock()
.map_err(|_| AiTaskServiceError::Store("AI 任务仓储锁已中毒".to_string()))?;
{
let task = state
.tasks
.get_mut(&chunk.task_id)
.ok_or(AiTaskServiceError::TaskNotFound)?;
ensure_task_is_not_terminal(task.status)?;
let stage = task
.stages
.iter_mut()
.find(|stage| stage.stage_kind == chunk.stage_kind)
.ok_or(AiTaskServiceError::StageNotFound)?;
if stage.status == AiTaskStageStatus::Pending {
stage.status = AiTaskStageStatus::Running;
stage.started_at_micros = Some(chunk.created_at_micros);
}
task.status = AiTaskStatus::Running;
task.started_at_micros
.get_or_insert(chunk.created_at_micros);
}
let chunks = state
.text_chunks
.get_mut(&chunk.task_id)
.ok_or(AiTaskServiceError::TaskNotFound)?;
chunks.push(chunk.clone());
chunks.sort_by_key(|value| value.sequence);
let aggregated_text = chunks
.iter()
.filter(|value| value.stage_kind == chunk.stage_kind)
.map(|value| value.delta_text.as_str())
.collect::<Vec<_>>()
.join("");
let normalized_output = if aggregated_text.trim().is_empty() {
None
} else {
Some(aggregated_text)
};
let task = state
.tasks
.get_mut(&chunk.task_id)
.ok_or(AiTaskServiceError::TaskNotFound)?;
let stage = task
.stages
.iter_mut()
.find(|stage| stage.stage_kind == chunk.stage_kind)
.ok_or(AiTaskServiceError::StageNotFound)?;
stage.text_output = normalized_output.clone();
task.latest_text_output = normalized_output;
task.updated_at_micros = chunk.created_at_micros;
task.version += 1;
Ok(task.clone())
}
pub(super) fn get_task(&self, task_id: &str) -> Result<AiTaskSnapshot, AiTaskServiceError> {
let state = self
.inner
.lock()
.map_err(|_| AiTaskServiceError::Store("AI 任务仓储锁已中毒".to_string()))?;
state
.tasks
.get(task_id.trim())
.cloned()
.ok_or(AiTaskServiceError::TaskNotFound)
}
}