117 lines
2.9 KiB
Rust
117 lines
2.9 KiB
Rust
use std::{
|
|
collections::HashMap,
|
|
sync::{
|
|
Arc,
|
|
atomic::{AtomicUsize, Ordering},
|
|
},
|
|
};
|
|
|
|
use async_trait::async_trait;
|
|
use serde_json::Value;
|
|
|
|
use crate::{
|
|
callbacks::{CreativeAgentCallbackEvent, CreativeAgentCallbackKind, CreativeAgentCallbacks},
|
|
error::PlatformAgentError,
|
|
};
|
|
|
|
#[async_trait]
|
|
pub trait CreativeAgentTool: Send + Sync {
|
|
fn name(&self) -> &str;
|
|
fn description(&self) -> &str;
|
|
fn input_schema(&self) -> Option<Value> {
|
|
None
|
|
}
|
|
async fn call(&self, input: Value) -> Result<Value, PlatformAgentError>;
|
|
}
|
|
|
|
#[derive(Debug)]
|
|
pub struct ToolExecutionBudget {
|
|
max_tool_calls: usize,
|
|
used_tool_calls: AtomicUsize,
|
|
}
|
|
|
|
impl ToolExecutionBudget {
|
|
pub fn new(max_tool_calls: usize) -> Result<Self, PlatformAgentError> {
|
|
if max_tool_calls == 0 {
|
|
return Err(PlatformAgentError::InvalidInput(
|
|
"Agent max_tool_calls 必须大于 0".to_string(),
|
|
));
|
|
}
|
|
|
|
Ok(Self {
|
|
max_tool_calls,
|
|
used_tool_calls: AtomicUsize::new(0),
|
|
})
|
|
}
|
|
|
|
pub fn reserve_call(&self) -> Result<(), PlatformAgentError> {
|
|
let previous = self.used_tool_calls.fetch_add(1, Ordering::SeqCst);
|
|
if previous >= self.max_tool_calls {
|
|
return Err(PlatformAgentError::ToolBudgetExceeded {
|
|
max_tool_calls: self.max_tool_calls,
|
|
});
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
#[derive(Default)]
|
|
pub struct CreativeAgentToolRegistry {
|
|
tools: HashMap<String, Arc<dyn CreativeAgentTool>>,
|
|
}
|
|
|
|
impl CreativeAgentToolRegistry {
|
|
pub fn new() -> Self {
|
|
Self::default()
|
|
}
|
|
|
|
pub fn register<T>(&mut self, tool: T)
|
|
where
|
|
T: CreativeAgentTool + 'static,
|
|
{
|
|
self.tools.insert(tool.name().to_string(), Arc::new(tool));
|
|
}
|
|
|
|
pub fn len(&self) -> usize {
|
|
self.tools.len()
|
|
}
|
|
|
|
pub fn is_empty(&self) -> bool {
|
|
self.tools.is_empty()
|
|
}
|
|
|
|
pub async fn execute(
|
|
&self,
|
|
name: &str,
|
|
input: Value,
|
|
budget: &ToolExecutionBudget,
|
|
callbacks: &CreativeAgentCallbacks,
|
|
) -> Result<Value, PlatformAgentError> {
|
|
budget.reserve_call()?;
|
|
let tool = self
|
|
.tools
|
|
.get(name)
|
|
.ok_or_else(|| PlatformAgentError::ToolNotFound(name.to_string()))?;
|
|
|
|
callbacks.emit(CreativeAgentCallbackEvent {
|
|
kind: CreativeAgentCallbackKind::ToolStarted,
|
|
label: name.to_string(),
|
|
detail: None,
|
|
});
|
|
|
|
let result = tool.call(input).await;
|
|
callbacks.emit(CreativeAgentCallbackEvent {
|
|
kind: if result.is_ok() {
|
|
CreativeAgentCallbackKind::ToolCompleted
|
|
} else {
|
|
CreativeAgentCallbackKind::Error
|
|
},
|
|
label: name.to_string(),
|
|
detail: result.as_ref().err().map(ToString::to_string),
|
|
});
|
|
|
|
result
|
|
}
|
|
}
|