139 lines
4.3 KiB
Rust
139 lines
4.3 KiB
Rust
use std::{
|
|
collections::HashMap,
|
|
sync::{Arc, Mutex},
|
|
};
|
|
|
|
use crate::{
|
|
AiTaskServiceError, AiTaskSnapshot, AiTaskStageStatus, AiTaskStatus, AiTextChunkSnapshot,
|
|
};
|
|
|
|
use super::ensure_task_is_not_terminal;
|
|
|
|
#[derive(Clone, Debug, Default)]
|
|
pub struct InMemoryAiTaskStore {
|
|
inner: Arc<Mutex<InMemoryAiTaskStoreState>>,
|
|
}
|
|
|
|
#[derive(Debug, Default)]
|
|
struct InMemoryAiTaskStoreState {
|
|
tasks: HashMap<String, AiTaskSnapshot>,
|
|
text_chunks: HashMap<String, Vec<AiTextChunkSnapshot>>,
|
|
}
|
|
|
|
impl InMemoryAiTaskStore {
|
|
pub(super) fn insert_task(
|
|
&self,
|
|
task: AiTaskSnapshot,
|
|
) -> Result<AiTaskSnapshot, AiTaskServiceError> {
|
|
let mut state = self
|
|
.inner
|
|
.lock()
|
|
.map_err(|_| AiTaskServiceError::Store("AI 任务仓储锁已中毒".to_string()))?;
|
|
|
|
if state.tasks.contains_key(&task.task_id) {
|
|
return Err(AiTaskServiceError::TaskAlreadyExists);
|
|
}
|
|
|
|
state.text_chunks.insert(task.task_id.clone(), Vec::new());
|
|
state.tasks.insert(task.task_id.clone(), task.clone());
|
|
Ok(task)
|
|
}
|
|
|
|
pub(super) fn update_task<F>(
|
|
&self,
|
|
task_id: &str,
|
|
mut apply: F,
|
|
) -> Result<AiTaskSnapshot, AiTaskServiceError>
|
|
where
|
|
F: FnMut(&mut AiTaskSnapshot) -> Result<(), AiTaskServiceError>,
|
|
{
|
|
let mut state = self
|
|
.inner
|
|
.lock()
|
|
.map_err(|_| AiTaskServiceError::Store("AI 任务仓储锁已中毒".to_string()))?;
|
|
let task = state
|
|
.tasks
|
|
.get_mut(task_id.trim())
|
|
.ok_or(AiTaskServiceError::TaskNotFound)?;
|
|
apply(task)?;
|
|
Ok(task.clone())
|
|
}
|
|
|
|
pub(super) fn append_text_chunk(
|
|
&self,
|
|
chunk: AiTextChunkSnapshot,
|
|
) -> Result<AiTaskSnapshot, AiTaskServiceError> {
|
|
let mut state = self
|
|
.inner
|
|
.lock()
|
|
.map_err(|_| AiTaskServiceError::Store("AI 任务仓储锁已中毒".to_string()))?;
|
|
{
|
|
let task = state
|
|
.tasks
|
|
.get_mut(&chunk.task_id)
|
|
.ok_or(AiTaskServiceError::TaskNotFound)?;
|
|
ensure_task_is_not_terminal(task.status)?;
|
|
|
|
let stage = task
|
|
.stages
|
|
.iter_mut()
|
|
.find(|stage| stage.stage_kind == chunk.stage_kind)
|
|
.ok_or(AiTaskServiceError::StageNotFound)?;
|
|
if stage.status == AiTaskStageStatus::Pending {
|
|
stage.status = AiTaskStageStatus::Running;
|
|
stage.started_at_micros = Some(chunk.created_at_micros);
|
|
}
|
|
|
|
task.status = AiTaskStatus::Running;
|
|
task.started_at_micros
|
|
.get_or_insert(chunk.created_at_micros);
|
|
}
|
|
|
|
let chunks = state
|
|
.text_chunks
|
|
.get_mut(&chunk.task_id)
|
|
.ok_or(AiTaskServiceError::TaskNotFound)?;
|
|
chunks.push(chunk.clone());
|
|
chunks.sort_by_key(|value| value.sequence);
|
|
|
|
let aggregated_text = chunks
|
|
.iter()
|
|
.filter(|value| value.stage_kind == chunk.stage_kind)
|
|
.map(|value| value.delta_text.as_str())
|
|
.collect::<Vec<_>>()
|
|
.join("");
|
|
let normalized_output = if aggregated_text.trim().is_empty() {
|
|
None
|
|
} else {
|
|
Some(aggregated_text)
|
|
};
|
|
|
|
let task = state
|
|
.tasks
|
|
.get_mut(&chunk.task_id)
|
|
.ok_or(AiTaskServiceError::TaskNotFound)?;
|
|
let stage = task
|
|
.stages
|
|
.iter_mut()
|
|
.find(|stage| stage.stage_kind == chunk.stage_kind)
|
|
.ok_or(AiTaskServiceError::StageNotFound)?;
|
|
stage.text_output = normalized_output.clone();
|
|
task.latest_text_output = normalized_output;
|
|
task.updated_at_micros = chunk.created_at_micros;
|
|
task.version += 1;
|
|
Ok(task.clone())
|
|
}
|
|
|
|
pub(super) fn get_task(&self, task_id: &str) -> Result<AiTaskSnapshot, AiTaskServiceError> {
|
|
let state = self
|
|
.inner
|
|
.lock()
|
|
.map_err(|_| AiTaskServiceError::Store("AI 任务仓储锁已中毒".to_string()))?;
|
|
state
|
|
.tasks
|
|
.get(task_id.trim())
|
|
.cloned()
|
|
.ok_or(AiTaskServiceError::TaskNotFound)
|
|
}
|
|
}
|