use std::{ collections::HashMap, sync::{Arc, Mutex}, }; use crate::{ AiTaskServiceError, AiTaskSnapshot, AiTaskStageStatus, AiTaskStatus, AiTextChunkSnapshot, }; use super::ensure_task_is_not_terminal; #[derive(Clone, Debug, Default)] pub struct InMemoryAiTaskStore { inner: Arc>, } #[derive(Debug, Default)] struct InMemoryAiTaskStoreState { tasks: HashMap, text_chunks: HashMap>, } impl InMemoryAiTaskStore { pub(super) fn insert_task( &self, task: AiTaskSnapshot, ) -> Result { let mut state = self .inner .lock() .map_err(|_| AiTaskServiceError::Store("AI 任务仓储锁已中毒".to_string()))?; if state.tasks.contains_key(&task.task_id) { return Err(AiTaskServiceError::TaskAlreadyExists); } state.text_chunks.insert(task.task_id.clone(), Vec::new()); state.tasks.insert(task.task_id.clone(), task.clone()); Ok(task) } pub(super) fn update_task( &self, task_id: &str, mut apply: F, ) -> Result where F: FnMut(&mut AiTaskSnapshot) -> Result<(), AiTaskServiceError>, { let mut state = self .inner .lock() .map_err(|_| AiTaskServiceError::Store("AI 任务仓储锁已中毒".to_string()))?; let task = state .tasks .get_mut(task_id.trim()) .ok_or(AiTaskServiceError::TaskNotFound)?; apply(task)?; Ok(task.clone()) } pub(super) fn append_text_chunk( &self, chunk: AiTextChunkSnapshot, ) -> Result { let mut state = self .inner .lock() .map_err(|_| AiTaskServiceError::Store("AI 任务仓储锁已中毒".to_string()))?; { let task = state .tasks .get_mut(&chunk.task_id) .ok_or(AiTaskServiceError::TaskNotFound)?; ensure_task_is_not_terminal(task.status)?; let stage = task .stages .iter_mut() .find(|stage| stage.stage_kind == chunk.stage_kind) .ok_or(AiTaskServiceError::StageNotFound)?; if stage.status == AiTaskStageStatus::Pending { stage.status = AiTaskStageStatus::Running; stage.started_at_micros = Some(chunk.created_at_micros); } task.status = AiTaskStatus::Running; task.started_at_micros .get_or_insert(chunk.created_at_micros); } let chunks = state .text_chunks .get_mut(&chunk.task_id) .ok_or(AiTaskServiceError::TaskNotFound)?; chunks.push(chunk.clone()); chunks.sort_by_key(|value| value.sequence); let aggregated_text = chunks .iter() .filter(|value| value.stage_kind == chunk.stage_kind) .map(|value| value.delta_text.as_str()) .collect::>() .join(""); let normalized_output = if aggregated_text.trim().is_empty() { None } else { Some(aggregated_text) }; let task = state .tasks .get_mut(&chunk.task_id) .ok_or(AiTaskServiceError::TaskNotFound)?; let stage = task .stages .iter_mut() .find(|stage| stage.stage_kind == chunk.stage_kind) .ok_or(AiTaskServiceError::StageNotFound)?; stage.text_output = normalized_output.clone(); task.latest_text_output = normalized_output; task.updated_at_micros = chunk.created_at_micros; task.version += 1; Ok(task.clone()) } pub(super) fn get_task(&self, task_id: &str) -> Result { let state = self .inner .lock() .map_err(|_| AiTaskServiceError::Store("AI 任务仓储锁已中毒".to_string()))?; state .tasks .get(task_id.trim()) .cloned() .ok_or(AiTaskServiceError::TaskNotFound) } }