1
This commit is contained in:
116
server-rs/crates/platform-agent/src/tool_registry.rs
Normal file
116
server-rs/crates/platform-agent/src/tool_registry.rs
Normal file
@@ -0,0 +1,116 @@
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
sync::{
|
||||
Arc,
|
||||
atomic::{AtomicUsize, Ordering},
|
||||
},
|
||||
};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::{
|
||||
callbacks::{CreativeAgentCallbackEvent, CreativeAgentCallbackKind, CreativeAgentCallbacks},
|
||||
error::PlatformAgentError,
|
||||
};
|
||||
|
||||
#[async_trait]
|
||||
pub trait CreativeAgentTool: Send + Sync {
|
||||
fn name(&self) -> &str;
|
||||
fn description(&self) -> &str;
|
||||
fn input_schema(&self) -> Option<Value> {
|
||||
None
|
||||
}
|
||||
async fn call(&self, input: Value) -> Result<Value, PlatformAgentError>;
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ToolExecutionBudget {
|
||||
max_tool_calls: usize,
|
||||
used_tool_calls: AtomicUsize,
|
||||
}
|
||||
|
||||
impl ToolExecutionBudget {
|
||||
pub fn new(max_tool_calls: usize) -> Result<Self, PlatformAgentError> {
|
||||
if max_tool_calls == 0 {
|
||||
return Err(PlatformAgentError::InvalidInput(
|
||||
"Agent max_tool_calls 必须大于 0".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
max_tool_calls,
|
||||
used_tool_calls: AtomicUsize::new(0),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn reserve_call(&self) -> Result<(), PlatformAgentError> {
|
||||
let previous = self.used_tool_calls.fetch_add(1, Ordering::SeqCst);
|
||||
if previous >= self.max_tool_calls {
|
||||
return Err(PlatformAgentError::ToolBudgetExceeded {
|
||||
max_tool_calls: self.max_tool_calls,
|
||||
});
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct CreativeAgentToolRegistry {
|
||||
tools: HashMap<String, Arc<dyn CreativeAgentTool>>,
|
||||
}
|
||||
|
||||
impl CreativeAgentToolRegistry {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
pub fn register<T>(&mut self, tool: T)
|
||||
where
|
||||
T: CreativeAgentTool + 'static,
|
||||
{
|
||||
self.tools.insert(tool.name().to_string(), Arc::new(tool));
|
||||
}
|
||||
|
||||
pub fn len(&self) -> usize {
|
||||
self.tools.len()
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.tools.is_empty()
|
||||
}
|
||||
|
||||
pub async fn execute(
|
||||
&self,
|
||||
name: &str,
|
||||
input: Value,
|
||||
budget: &ToolExecutionBudget,
|
||||
callbacks: &CreativeAgentCallbacks,
|
||||
) -> Result<Value, PlatformAgentError> {
|
||||
budget.reserve_call()?;
|
||||
let tool = self
|
||||
.tools
|
||||
.get(name)
|
||||
.ok_or_else(|| PlatformAgentError::ToolNotFound(name.to_string()))?;
|
||||
|
||||
callbacks.emit(CreativeAgentCallbackEvent {
|
||||
kind: CreativeAgentCallbackKind::ToolStarted,
|
||||
label: name.to_string(),
|
||||
detail: None,
|
||||
});
|
||||
|
||||
let result = tool.call(input).await;
|
||||
callbacks.emit(CreativeAgentCallbackEvent {
|
||||
kind: if result.is_ok() {
|
||||
CreativeAgentCallbackKind::ToolCompleted
|
||||
} else {
|
||||
CreativeAgentCallbackKind::Error
|
||||
},
|
||||
label: name.to_string(),
|
||||
detail: result.as_ref().err().map(ToString::to_string),
|
||||
});
|
||||
|
||||
result
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user