use std::{ collections::HashMap, sync::{ Arc, atomic::{AtomicUsize, Ordering}, }, }; use async_trait::async_trait; use serde_json::Value; use crate::{ callbacks::{CreativeAgentCallbackEvent, CreativeAgentCallbackKind, CreativeAgentCallbacks}, error::PlatformAgentError, }; #[async_trait] pub trait CreativeAgentTool: Send + Sync { fn name(&self) -> &str; fn description(&self) -> &str; fn input_schema(&self) -> Option { None } async fn call(&self, input: Value) -> Result; } #[derive(Debug)] pub struct ToolExecutionBudget { max_tool_calls: usize, used_tool_calls: AtomicUsize, } impl ToolExecutionBudget { pub fn new(max_tool_calls: usize) -> Result { if max_tool_calls == 0 { return Err(PlatformAgentError::InvalidInput( "Agent max_tool_calls 必须大于 0".to_string(), )); } Ok(Self { max_tool_calls, used_tool_calls: AtomicUsize::new(0), }) } pub fn reserve_call(&self) -> Result<(), PlatformAgentError> { let previous = self.used_tool_calls.fetch_add(1, Ordering::SeqCst); if previous >= self.max_tool_calls { return Err(PlatformAgentError::ToolBudgetExceeded { max_tool_calls: self.max_tool_calls, }); } Ok(()) } } #[derive(Default)] pub struct CreativeAgentToolRegistry { tools: HashMap>, } impl CreativeAgentToolRegistry { pub fn new() -> Self { Self::default() } pub fn register(&mut self, tool: T) where T: CreativeAgentTool + 'static, { self.tools.insert(tool.name().to_string(), Arc::new(tool)); } pub fn len(&self) -> usize { self.tools.len() } pub fn is_empty(&self) -> bool { self.tools.is_empty() } pub async fn execute( &self, name: &str, input: Value, budget: &ToolExecutionBudget, callbacks: &CreativeAgentCallbacks, ) -> Result { budget.reserve_call()?; let tool = self .tools .get(name) .ok_or_else(|| PlatformAgentError::ToolNotFound(name.to_string()))?; callbacks.emit(CreativeAgentCallbackEvent { kind: CreativeAgentCallbackKind::ToolStarted, label: name.to_string(), detail: None, }); let result = tool.call(input).await; callbacks.emit(CreativeAgentCallbackEvent { kind: if result.is_ok() { CreativeAgentCallbackKind::ToolCompleted } else { CreativeAgentCallbackKind::Error }, label: name.to_string(), detail: result.as_ref().err().map(ToString::to_string), }); result } }