Files
Genarrative/server-rs/crates/platform-agent/src/tool_registry.rs
2026-05-08 11:44:42 +08:00

117 lines
2.9 KiB
Rust

use std::{
collections::HashMap,
sync::{
Arc,
atomic::{AtomicUsize, Ordering},
},
};
use async_trait::async_trait;
use serde_json::Value;
use crate::{
callbacks::{CreativeAgentCallbackEvent, CreativeAgentCallbackKind, CreativeAgentCallbacks},
error::PlatformAgentError,
};
/// A named, asynchronously callable tool that a creative agent can invoke.
///
/// The `Send + Sync` supertraits allow one tool instance to be shared between
/// tasks/threads.
#[async_trait]
pub trait CreativeAgentTool: Send + Sync {
    /// Name the tool is registered and looked up under.
    fn name(&self) -> &str;

    /// Human-readable description of what the tool does.
    fn description(&self) -> &str;

    /// Optional description of the accepted input, expressed as a JSON value.
    ///
    /// NOTE(review): presumably a JSON Schema advertised to the model — confirm
    /// with the consumers of this method. The default advertises nothing.
    fn input_schema(&self) -> Option<Value> {
        Option::None
    }

    /// Executes the tool with a JSON `input` and yields a JSON result.
    ///
    /// # Errors
    ///
    /// Implementations report failures as [`PlatformAgentError`].
    async fn call(&self, input: Value) -> Result<Value, PlatformAgentError>;
}
/// Thread-safe cap on how many tool calls may be made against one budget.
///
/// Reservations go through an atomic counter, so a single budget can be shared
/// by reference across concurrently executing tool calls.
#[derive(Debug)]
pub struct ToolExecutionBudget {
    /// Maximum number of reservations that will be granted; always > 0.
    max_tool_calls: usize,
    /// Number of reservations granted so far; never exceeds `max_tool_calls`.
    used_tool_calls: AtomicUsize,
}

impl ToolExecutionBudget {
    /// Creates a budget that grants at most `max_tool_calls` reservations.
    ///
    /// # Errors
    ///
    /// Returns [`PlatformAgentError::InvalidInput`] when `max_tool_calls` is `0`.
    pub fn new(max_tool_calls: usize) -> Result<Self, PlatformAgentError> {
        if max_tool_calls == 0 {
            return Err(PlatformAgentError::InvalidInput(
                "Agent max_tool_calls 必须大于 0".to_string(),
            ));
        }
        Ok(Self {
            max_tool_calls,
            used_tool_calls: AtomicUsize::new(0),
        })
    }

    /// Returns how many reservations have been granted so far.
    pub fn used_tool_calls(&self) -> usize {
        self.used_tool_calls.load(Ordering::SeqCst)
    }

    /// Returns how many more reservations can still be granted.
    pub fn remaining_tool_calls(&self) -> usize {
        // `used` never exceeds the max; saturate anyway as a cheap guard.
        self.max_tool_calls.saturating_sub(self.used_tool_calls())
    }

    /// Reserves one tool call from the budget.
    ///
    /// The counter is only incremented when a slot is actually free. Rejected
    /// reservations therefore neither inflate the count nor risk wrapping it
    /// around, which would silently re-open an exhausted budget.
    ///
    /// # Errors
    ///
    /// Returns [`PlatformAgentError::ToolBudgetExceeded`] once `max_tool_calls`
    /// reservations have already been granted.
    pub fn reserve_call(&self) -> Result<(), PlatformAgentError> {
        self.used_tool_calls
            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |used| {
                // `used < max` guarantees `used + 1` cannot overflow.
                (used < self.max_tool_calls).then_some(used + 1)
            })
            .map(|_| ())
            .map_err(|_| PlatformAgentError::ToolBudgetExceeded {
                max_tool_calls: self.max_tool_calls,
            })
    }
}
#[derive(Default)]
pub struct CreativeAgentToolRegistry {
tools: HashMap<String, Arc<dyn CreativeAgentTool>>,
}
impl CreativeAgentToolRegistry {
pub fn new() -> Self {
Self::default()
}
pub fn register<T>(&mut self, tool: T)
where
T: CreativeAgentTool + 'static,
{
self.tools.insert(tool.name().to_string(), Arc::new(tool));
}
pub fn len(&self) -> usize {
self.tools.len()
}
pub fn is_empty(&self) -> bool {
self.tools.is_empty()
}
pub async fn execute(
&self,
name: &str,
input: Value,
budget: &ToolExecutionBudget,
callbacks: &CreativeAgentCallbacks,
) -> Result<Value, PlatformAgentError> {
budget.reserve_call()?;
let tool = self
.tools
.get(name)
.ok_or_else(|| PlatformAgentError::ToolNotFound(name.to_string()))?;
callbacks.emit(CreativeAgentCallbackEvent {
kind: CreativeAgentCallbackKind::ToolStarted,
label: name.to_string(),
detail: None,
});
let result = tool.call(input).await;
callbacks.emit(CreativeAgentCallbackEvent {
kind: if result.is_ok() {
CreativeAgentCallbackKind::ToolCompleted
} else {
CreativeAgentCallbackKind::Error
},
label: name.to_string(),
detail: result.as_ref().err().map(ToString::to_string),
});
result
}
}