1
This commit is contained in:
59
server-rs/crates/platform-agent/src/apimart_gpt5_adapter.rs
Normal file
59
server-rs/crates/platform-agent/src/apimart_gpt5_adapter.rs
Normal file
@@ -0,0 +1,59 @@
|
||||
use platform_llm::{
|
||||
LlmClient, LlmMessage, LlmMessageContentPart, LlmMessageRole, LlmTextProtocol, LlmTextRequest,
|
||||
LlmTextResponse,
|
||||
};
|
||||
|
||||
use crate::error::PlatformAgentError;
|
||||
|
||||
pub const CREATIVE_AGENT_GPT5_MODEL: &str = "gpt-5";
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Gpt5ResponsesAgentClient {
|
||||
llm_client: LlmClient,
|
||||
}
|
||||
|
||||
impl Gpt5ResponsesAgentClient {
|
||||
pub fn new(llm_client: LlmClient) -> Self {
|
||||
Self { llm_client }
|
||||
}
|
||||
|
||||
pub async fn request(
|
||||
&self,
|
||||
system_prompt: impl Into<String>,
|
||||
user_text: impl Into<String>,
|
||||
image_urls: Vec<String>,
|
||||
) -> Result<LlmTextResponse, PlatformAgentError> {
|
||||
let request = build_gpt5_multimodal_request(system_prompt, user_text, image_urls);
|
||||
self.llm_client
|
||||
.request_text(request)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn build_gpt5_multimodal_request(
|
||||
system_prompt: impl Into<String>,
|
||||
user_text: impl Into<String>,
|
||||
image_urls: Vec<String>,
|
||||
) -> LlmTextRequest {
|
||||
let mut user_parts = vec![LlmMessageContentPart::InputText {
|
||||
text: user_text.into(),
|
||||
}];
|
||||
user_parts.extend(
|
||||
image_urls
|
||||
.into_iter()
|
||||
.map(|image_url| LlmMessageContentPart::InputImage { image_url }),
|
||||
);
|
||||
|
||||
LlmTextRequest {
|
||||
model: Some(CREATIVE_AGENT_GPT5_MODEL.to_string()),
|
||||
messages: vec![
|
||||
LlmMessage::new(LlmMessageRole::System, system_prompt.into()),
|
||||
LlmMessage::multimodal(LlmMessageRole::User, user_parts),
|
||||
],
|
||||
max_tokens: None,
|
||||
request_timeout_ms: None,
|
||||
enable_web_search: false,
|
||||
protocol: LlmTextProtocol::Responses,
|
||||
}
|
||||
}
|
||||
54
server-rs/crates/platform-agent/src/callbacks.rs
Normal file
54
server-rs/crates/platform-agent/src/callbacks.rs
Normal file
@@ -0,0 +1,54 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub enum CreativeAgentCallbackKind {
|
||||
Stage,
|
||||
ToolStarted,
|
||||
ToolCompleted,
|
||||
ModelRequestStarted,
|
||||
ModelRequestCompleted,
|
||||
Error,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub struct CreativeAgentCallbackEvent {
|
||||
pub kind: CreativeAgentCallbackKind,
|
||||
pub label: String,
|
||||
pub detail: Option<String>,
|
||||
}
|
||||
|
||||
type CreativeAgentCallbackFn = Arc<dyn Fn(CreativeAgentCallbackEvent) + Send + Sync>;
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
pub struct CreativeAgentCallbacks {
|
||||
on_event: Option<CreativeAgentCallbackFn>,
|
||||
}
|
||||
|
||||
impl CreativeAgentCallbacks {
|
||||
pub fn new<F>(on_event: F) -> Self
|
||||
where
|
||||
F: Fn(CreativeAgentCallbackEvent) + Send + Sync + 'static,
|
||||
{
|
||||
Self {
|
||||
on_event: Some(Arc::new(on_event)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn noop() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
pub fn emit(&self, event: CreativeAgentCallbackEvent) {
|
||||
if let Some(on_event) = &self.on_event {
|
||||
on_event(event);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn stage(&self, label: impl Into<String>) {
|
||||
self.emit(CreativeAgentCallbackEvent {
|
||||
kind: CreativeAgentCallbackKind::Stage,
|
||||
label: label.into(),
|
||||
detail: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
40
server-rs/crates/platform-agent/src/error.rs
Normal file
40
server-rs/crates/platform-agent/src/error.rs
Normal file
@@ -0,0 +1,40 @@
|
||||
use std::{error::Error, fmt};
|
||||
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub enum PlatformAgentError {
|
||||
InvalidInput(String),
|
||||
ToolNotFound(String),
|
||||
ToolExecution(String),
|
||||
ToolBudgetExceeded { max_tool_calls: usize },
|
||||
Timeout { timeout_ms: u64 },
|
||||
Llm(String),
|
||||
LangChain(String),
|
||||
OutputParse(String),
|
||||
}
|
||||
|
||||
impl fmt::Display for PlatformAgentError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
Self::InvalidInput(message)
|
||||
| Self::ToolExecution(message)
|
||||
| Self::Llm(message)
|
||||
| Self::LangChain(message)
|
||||
| Self::OutputParse(message) => write!(f, "{message}"),
|
||||
Self::ToolNotFound(name) => write!(f, "Agent 工具未注册:{name}"),
|
||||
Self::ToolBudgetExceeded { max_tool_calls } => {
|
||||
write!(f, "Agent 工具调用次数超过限制:{max_tool_calls}")
|
||||
}
|
||||
Self::Timeout { timeout_ms } => {
|
||||
write!(f, "Agent 执行超时:{timeout_ms}ms")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Error for PlatformAgentError {}
|
||||
|
||||
impl From<platform_llm::LlmError> for PlatformAgentError {
|
||||
fn from(error: platform_llm::LlmError) -> Self {
|
||||
Self::Llm(error.to_string())
|
||||
}
|
||||
}
|
||||
49
server-rs/crates/platform-agent/src/function_agent.rs
Normal file
49
server-rs/crates/platform-agent/src/function_agent.rs
Normal file
@@ -0,0 +1,49 @@
|
||||
use std::{future::Future, time::Duration};
|
||||
|
||||
use tokio::time::timeout;
|
||||
|
||||
use crate::error::PlatformAgentError;
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub struct FunctionAgentLimits {
|
||||
pub max_tool_calls: usize,
|
||||
pub timeout_ms: u64,
|
||||
}
|
||||
|
||||
impl Default for FunctionAgentLimits {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
max_tool_calls: 8,
|
||||
timeout_ms: 30_000,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FunctionAgentLimits {
|
||||
pub fn validate(&self) -> Result<(), PlatformAgentError> {
|
||||
if self.max_tool_calls == 0 {
|
||||
return Err(PlatformAgentError::InvalidInput(
|
||||
"Agent max_tool_calls 必须大于 0".to_string(),
|
||||
));
|
||||
}
|
||||
if self.timeout_ms == 0 {
|
||||
return Err(PlatformAgentError::InvalidInput(
|
||||
"Agent timeout_ms 必须大于 0".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn run_with_timeout<F, T>(&self, future: F) -> Result<T, PlatformAgentError>
|
||||
where
|
||||
F: Future<Output = Result<T, PlatformAgentError>>,
|
||||
{
|
||||
self.validate()?;
|
||||
timeout(Duration::from_millis(self.timeout_ms), future)
|
||||
.await
|
||||
.map_err(|_| PlatformAgentError::Timeout {
|
||||
timeout_ms: self.timeout_ms,
|
||||
})?
|
||||
}
|
||||
}
|
||||
177
server-rs/crates/platform-agent/src/langchain_adapter.rs
Normal file
177
server-rs/crates/platform-agent/src/langchain_adapter.rs
Normal file
@@ -0,0 +1,177 @@
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use langchainrust::{
|
||||
AgentAction, AgentError, AgentExecutor, AgentFinish, AgentOutput, AgentStep, BaseAgent,
|
||||
BaseTool, FunctionCallingAgent, OpenAIChat, OpenAIConfig, ToolError, ToolInput,
|
||||
};
|
||||
use serde_json::json;
|
||||
|
||||
use crate::{
|
||||
apimart_gpt5_adapter::CREATIVE_AGENT_GPT5_MODEL, error::PlatformAgentError,
|
||||
function_agent::FunctionAgentLimits,
|
||||
};
|
||||
|
||||
pub struct LangChainRustAdapter {
|
||||
limits: FunctionAgentLimits,
|
||||
}
|
||||
|
||||
impl LangChainRustAdapter {
|
||||
pub fn new(limits: FunctionAgentLimits) -> Result<Self, PlatformAgentError> {
|
||||
limits.validate()?;
|
||||
Ok(Self { limits })
|
||||
}
|
||||
|
||||
pub fn limits(&self) -> &FunctionAgentLimits {
|
||||
&self.limits
|
||||
}
|
||||
|
||||
pub fn build_function_calling_agent(
|
||||
&self,
|
||||
api_key: impl Into<String>,
|
||||
base_url: impl Into<String>,
|
||||
tools: Vec<Arc<dyn BaseTool>>,
|
||||
system_prompt: Option<String>,
|
||||
) -> FunctionCallingAgent {
|
||||
let config = OpenAIConfig::new(api_key)
|
||||
.with_base_url(base_url)
|
||||
.with_model(CREATIVE_AGENT_GPT5_MODEL);
|
||||
let llm = OpenAIChat::new(config);
|
||||
FunctionCallingAgent::new(llm, tools, system_prompt)
|
||||
}
|
||||
|
||||
pub async fn execute_minimal_tool_call(
|
||||
&self,
|
||||
tool_name: impl Into<String>,
|
||||
input: serde_json::Value,
|
||||
) -> Result<String, PlatformAgentError> {
|
||||
let tool_name = tool_name.into();
|
||||
let tools: Vec<Arc<dyn BaseTool>> =
|
||||
vec![Arc::new(EchoLangChainTool::new(tool_name.clone()))];
|
||||
let agent = Arc::new(ScriptedLangChainToolAgent::new(tool_name, input));
|
||||
let executor =
|
||||
AgentExecutor::new(agent, tools).with_max_iterations(self.limits.max_tool_calls);
|
||||
|
||||
self.limits
|
||||
.run_with_timeout(async move {
|
||||
executor
|
||||
.invoke("执行最小工具调用".to_string())
|
||||
.await
|
||||
.map_err(|error| PlatformAgentError::LangChain(error.to_string()))
|
||||
})
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
struct ScriptedLangChainToolAgent {
|
||||
tool_name: String,
|
||||
input: serde_json::Value,
|
||||
}
|
||||
|
||||
impl ScriptedLangChainToolAgent {
|
||||
fn new(tool_name: String, input: serde_json::Value) -> Self {
|
||||
Self { tool_name, input }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl BaseAgent for ScriptedLangChainToolAgent {
|
||||
async fn plan(
|
||||
&self,
|
||||
intermediate_steps: &[AgentStep],
|
||||
_inputs: &HashMap<String, String>,
|
||||
) -> Result<AgentOutput, AgentError> {
|
||||
if let Some(step) = intermediate_steps.last() {
|
||||
return Ok(AgentOutput::Finish(AgentFinish::new(
|
||||
step.observation.clone(),
|
||||
String::new(),
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(AgentOutput::Action(AgentAction {
|
||||
tool: self.tool_name.clone(),
|
||||
tool_input: ToolInput::Object(self.input.clone()),
|
||||
log: "scripted-call-1".to_string(),
|
||||
}))
|
||||
}
|
||||
|
||||
fn get_allowed_tools(&self) -> Option<Vec<&str>> {
|
||||
Some(vec![self.tool_name.as_str()])
|
||||
}
|
||||
}
|
||||
|
||||
struct EchoLangChainTool {
|
||||
name: String,
|
||||
}
|
||||
|
||||
impl EchoLangChainTool {
|
||||
fn new(name: String) -> Self {
|
||||
Self { name }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl BaseTool for EchoLangChainTool {
|
||||
fn name(&self) -> &str {
|
||||
&self.name
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"用于验证 LangChain-Rust AgentExecutor 能执行最小工具调用。"
|
||||
}
|
||||
|
||||
async fn run(&self, input: String) -> Result<String, ToolError> {
|
||||
Ok(json!({
|
||||
"ok": true,
|
||||
"tool": self.name,
|
||||
"input": input,
|
||||
})
|
||||
.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn langchain_adapter_executes_minimal_tool_call() {
|
||||
let adapter = LangChainRustAdapter::new(FunctionAgentLimits {
|
||||
max_tool_calls: 2,
|
||||
timeout_ms: 1_000,
|
||||
})
|
||||
.expect("limits should be valid");
|
||||
|
||||
let output = adapter
|
||||
.execute_minimal_tool_call("retrieve_puzzle_template_catalog", json!({"query": "拼图"}))
|
||||
.await
|
||||
.expect("langchain executor should run tool");
|
||||
|
||||
let parsed: serde_json::Value =
|
||||
serde_json::from_str(output.as_str()).expect("tool output should be json");
|
||||
assert_eq!(parsed["ok"], true);
|
||||
assert_eq!(parsed["tool"], "retrieve_puzzle_template_catalog");
|
||||
assert!(
|
||||
parsed["input"]
|
||||
.as_str()
|
||||
.unwrap_or_default()
|
||||
.contains("拼图")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn function_calling_agent_uses_gpt5_config_without_calling_network() {
|
||||
let adapter = LangChainRustAdapter::new(FunctionAgentLimits::default())
|
||||
.expect("limits should be valid");
|
||||
let agent = adapter.build_function_calling_agent(
|
||||
"test-key",
|
||||
"http://127.0.0.1:9/v1",
|
||||
Vec::new(),
|
||||
Some("系统提示".to_string()),
|
||||
);
|
||||
|
||||
let debug_text = format!("{agent:?}");
|
||||
assert!(debug_text.contains("FunctionCallingAgent"));
|
||||
assert!(debug_text.contains("系统提示"));
|
||||
}
|
||||
}
|
||||
23
server-rs/crates/platform-agent/src/lib.rs
Normal file
23
server-rs/crates/platform-agent/src/lib.rs
Normal file
@@ -0,0 +1,23 @@
|
||||
pub mod apimart_gpt5_adapter;
|
||||
pub mod callbacks;
|
||||
pub mod error;
|
||||
pub mod function_agent;
|
||||
pub mod langchain_adapter;
|
||||
pub mod output_parser;
|
||||
pub mod puzzle_phase1_agent;
|
||||
pub mod tool_registry;
|
||||
|
||||
pub use apimart_gpt5_adapter::{
|
||||
CREATIVE_AGENT_GPT5_MODEL, Gpt5ResponsesAgentClient, build_gpt5_multimodal_request,
|
||||
};
|
||||
pub use callbacks::{
|
||||
CreativeAgentCallbackEvent, CreativeAgentCallbackKind, CreativeAgentCallbacks,
|
||||
};
|
||||
pub use error::PlatformAgentError;
|
||||
pub use function_agent::FunctionAgentLimits;
|
||||
pub use langchain_adapter::LangChainRustAdapter;
|
||||
pub use puzzle_phase1_agent::{
|
||||
CreativeAgentExecutor, MockLangChainRustAgentExecutor, PuzzlePhase1AgentInput,
|
||||
PuzzlePhase1AgentOutput,
|
||||
};
|
||||
pub use tool_registry::{CreativeAgentTool, CreativeAgentToolRegistry, ToolExecutionBudget};
|
||||
12
server-rs/crates/platform-agent/src/output_parser.rs
Normal file
12
server-rs/crates/platform-agent/src/output_parser.rs
Normal file
@@ -0,0 +1,12 @@
|
||||
use serde::de::DeserializeOwned;
|
||||
|
||||
use crate::error::PlatformAgentError;
|
||||
|
||||
pub fn parse_json_output<T>(raw_text: &str) -> Result<T, PlatformAgentError>
|
||||
where
|
||||
T: DeserializeOwned,
|
||||
{
|
||||
serde_json::from_str(raw_text).map_err(|error| {
|
||||
PlatformAgentError::OutputParse(format!("解析 Agent JSON 输出失败:{error}"))
|
||||
})
|
||||
}
|
||||
112
server-rs/crates/platform-agent/src/puzzle_phase1_agent.rs
Normal file
112
server-rs/crates/platform-agent/src/puzzle_phase1_agent.rs
Normal file
@@ -0,0 +1,112 @@
|
||||
use async_trait::async_trait;
|
||||
|
||||
use crate::{
|
||||
callbacks::CreativeAgentCallbacks, error::PlatformAgentError,
|
||||
function_agent::FunctionAgentLimits,
|
||||
};
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub struct PuzzlePhase1AgentInput {
|
||||
pub session_id: String,
|
||||
pub user_text: String,
|
||||
pub image_urls: Vec<String>,
|
||||
pub limits: FunctionAgentLimits,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub struct PuzzlePhase1AgentOutput {
|
||||
pub session_id: String,
|
||||
pub model: String,
|
||||
pub final_message: String,
|
||||
pub tool_call_count: usize,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait CreativeAgentExecutor: Send + Sync {
|
||||
async fn run_puzzle_phase1(
|
||||
&self,
|
||||
input: PuzzlePhase1AgentInput,
|
||||
callbacks: CreativeAgentCallbacks,
|
||||
) -> Result<PuzzlePhase1AgentOutput, PlatformAgentError>;
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub struct MockLangChainRustAgentExecutor;
|
||||
|
||||
#[async_trait]
|
||||
impl CreativeAgentExecutor for MockLangChainRustAgentExecutor {
|
||||
async fn run_puzzle_phase1(
|
||||
&self,
|
||||
input: PuzzlePhase1AgentInput,
|
||||
callbacks: CreativeAgentCallbacks,
|
||||
) -> Result<PuzzlePhase1AgentOutput, PlatformAgentError> {
|
||||
input.limits.validate()?;
|
||||
if input.session_id.trim().is_empty() {
|
||||
return Err(PlatformAgentError::InvalidInput(
|
||||
"creative session_id 不能为空".to_string(),
|
||||
));
|
||||
}
|
||||
if input.user_text.trim().is_empty() && input.image_urls.is_empty() {
|
||||
return Err(PlatformAgentError::InvalidInput(
|
||||
"创意 Agent 输入文本和图片不能同时为空".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
callbacks.stage("perceiving");
|
||||
callbacks.stage("thinking");
|
||||
callbacks.stage("remembering");
|
||||
callbacks.stage("selecting_puzzle_template");
|
||||
callbacks.stage("waiting_template_confirmation");
|
||||
|
||||
Ok(PuzzlePhase1AgentOutput {
|
||||
session_id: input.session_id,
|
||||
model: crate::apimart_gpt5_adapter::CREATIVE_AGENT_GPT5_MODEL.to_string(),
|
||||
final_message: "已完成创意 Agent Phase 1 mock 执行,可进入模板确认。".to_string(),
|
||||
tool_call_count: 0,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn mock_executor_keeps_gpt5_model_and_emits_stages() {
|
||||
let events = Arc::new(Mutex::new(Vec::new()));
|
||||
let events_clone = events.clone();
|
||||
let callbacks = CreativeAgentCallbacks::new(move |event| {
|
||||
events_clone.lock().expect("events lock").push(event.label);
|
||||
});
|
||||
|
||||
let output = MockLangChainRustAgentExecutor
|
||||
.run_puzzle_phase1(
|
||||
PuzzlePhase1AgentInput {
|
||||
session_id: "creative-session-test".to_string(),
|
||||
user_text: "做一个家庭纪念拼图".to_string(),
|
||||
image_urls: vec!["https://example.com/ref.png".to_string()],
|
||||
limits: FunctionAgentLimits::default(),
|
||||
},
|
||||
callbacks,
|
||||
)
|
||||
.await
|
||||
.expect("mock executor should succeed");
|
||||
|
||||
assert_eq!(
|
||||
output.model,
|
||||
crate::apimart_gpt5_adapter::CREATIVE_AGENT_GPT5_MODEL
|
||||
);
|
||||
assert_eq!(
|
||||
events.lock().expect("events lock").as_slice(),
|
||||
[
|
||||
"perceiving",
|
||||
"thinking",
|
||||
"remembering",
|
||||
"selecting_puzzle_template",
|
||||
"waiting_template_confirmation",
|
||||
]
|
||||
);
|
||||
}
|
||||
}
|
||||
116
server-rs/crates/platform-agent/src/tool_registry.rs
Normal file
116
server-rs/crates/platform-agent/src/tool_registry.rs
Normal file
@@ -0,0 +1,116 @@
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
sync::{
|
||||
Arc,
|
||||
atomic::{AtomicUsize, Ordering},
|
||||
},
|
||||
};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::{
|
||||
callbacks::{CreativeAgentCallbackEvent, CreativeAgentCallbackKind, CreativeAgentCallbacks},
|
||||
error::PlatformAgentError,
|
||||
};
|
||||
|
||||
#[async_trait]
|
||||
pub trait CreativeAgentTool: Send + Sync {
|
||||
fn name(&self) -> &str;
|
||||
fn description(&self) -> &str;
|
||||
fn input_schema(&self) -> Option<Value> {
|
||||
None
|
||||
}
|
||||
async fn call(&self, input: Value) -> Result<Value, PlatformAgentError>;
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ToolExecutionBudget {
|
||||
max_tool_calls: usize,
|
||||
used_tool_calls: AtomicUsize,
|
||||
}
|
||||
|
||||
impl ToolExecutionBudget {
|
||||
pub fn new(max_tool_calls: usize) -> Result<Self, PlatformAgentError> {
|
||||
if max_tool_calls == 0 {
|
||||
return Err(PlatformAgentError::InvalidInput(
|
||||
"Agent max_tool_calls 必须大于 0".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
max_tool_calls,
|
||||
used_tool_calls: AtomicUsize::new(0),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn reserve_call(&self) -> Result<(), PlatformAgentError> {
|
||||
let previous = self.used_tool_calls.fetch_add(1, Ordering::SeqCst);
|
||||
if previous >= self.max_tool_calls {
|
||||
return Err(PlatformAgentError::ToolBudgetExceeded {
|
||||
max_tool_calls: self.max_tool_calls,
|
||||
});
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct CreativeAgentToolRegistry {
|
||||
tools: HashMap<String, Arc<dyn CreativeAgentTool>>,
|
||||
}
|
||||
|
||||
impl CreativeAgentToolRegistry {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
pub fn register<T>(&mut self, tool: T)
|
||||
where
|
||||
T: CreativeAgentTool + 'static,
|
||||
{
|
||||
self.tools.insert(tool.name().to_string(), Arc::new(tool));
|
||||
}
|
||||
|
||||
pub fn len(&self) -> usize {
|
||||
self.tools.len()
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.tools.is_empty()
|
||||
}
|
||||
|
||||
pub async fn execute(
|
||||
&self,
|
||||
name: &str,
|
||||
input: Value,
|
||||
budget: &ToolExecutionBudget,
|
||||
callbacks: &CreativeAgentCallbacks,
|
||||
) -> Result<Value, PlatformAgentError> {
|
||||
budget.reserve_call()?;
|
||||
let tool = self
|
||||
.tools
|
||||
.get(name)
|
||||
.ok_or_else(|| PlatformAgentError::ToolNotFound(name.to_string()))?;
|
||||
|
||||
callbacks.emit(CreativeAgentCallbackEvent {
|
||||
kind: CreativeAgentCallbackKind::ToolStarted,
|
||||
label: name.to_string(),
|
||||
detail: None,
|
||||
});
|
||||
|
||||
let result = tool.call(input).await;
|
||||
callbacks.emit(CreativeAgentCallbackEvent {
|
||||
kind: if result.is_ok() {
|
||||
CreativeAgentCallbackKind::ToolCompleted
|
||||
} else {
|
||||
CreativeAgentCallbackKind::Error
|
||||
},
|
||||
label: name.to_string(),
|
||||
detail: result.as_ref().err().map(ToString::to_string),
|
||||
});
|
||||
|
||||
result
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user