178 lines
5.1 KiB
Rust
178 lines
5.1 KiB
Rust
use std::{collections::HashMap, sync::Arc};
|
|
|
|
use async_trait::async_trait;
|
|
use langchainrust::{
|
|
AgentAction, AgentError, AgentExecutor, AgentFinish, AgentOutput, AgentStep, BaseAgent,
|
|
BaseTool, FunctionCallingAgent, OpenAIChat, OpenAIConfig, ToolError, ToolInput,
|
|
};
|
|
use serde_json::json;
|
|
|
|
use crate::{
|
|
apimart_gpt5_adapter::CREATIVE_AGENT_GPT5_MODEL, error::PlatformAgentError,
|
|
function_agent::FunctionAgentLimits,
|
|
};
|
|
|
|
pub struct LangChainRustAdapter {
|
|
limits: FunctionAgentLimits,
|
|
}
|
|
|
|
impl LangChainRustAdapter {
|
|
pub fn new(limits: FunctionAgentLimits) -> Result<Self, PlatformAgentError> {
|
|
limits.validate()?;
|
|
Ok(Self { limits })
|
|
}
|
|
|
|
pub fn limits(&self) -> &FunctionAgentLimits {
|
|
&self.limits
|
|
}
|
|
|
|
pub fn build_function_calling_agent(
|
|
&self,
|
|
api_key: impl Into<String>,
|
|
base_url: impl Into<String>,
|
|
tools: Vec<Arc<dyn BaseTool>>,
|
|
system_prompt: Option<String>,
|
|
) -> FunctionCallingAgent {
|
|
let config = OpenAIConfig::new(api_key)
|
|
.with_base_url(base_url)
|
|
.with_model(CREATIVE_AGENT_GPT5_MODEL);
|
|
let llm = OpenAIChat::new(config);
|
|
FunctionCallingAgent::new(llm, tools, system_prompt)
|
|
}
|
|
|
|
pub async fn execute_minimal_tool_call(
|
|
&self,
|
|
tool_name: impl Into<String>,
|
|
input: serde_json::Value,
|
|
) -> Result<String, PlatformAgentError> {
|
|
let tool_name = tool_name.into();
|
|
let tools: Vec<Arc<dyn BaseTool>> =
|
|
vec![Arc::new(EchoLangChainTool::new(tool_name.clone()))];
|
|
let agent = Arc::new(ScriptedLangChainToolAgent::new(tool_name, input));
|
|
let executor =
|
|
AgentExecutor::new(agent, tools).with_max_iterations(self.limits.max_tool_calls);
|
|
|
|
self.limits
|
|
.run_with_timeout(async move {
|
|
executor
|
|
.invoke("执行最小工具调用".to_string())
|
|
.await
|
|
.map_err(|error| PlatformAgentError::LangChain(error.to_string()))
|
|
})
|
|
.await
|
|
}
|
|
}
|
|
|
|
struct ScriptedLangChainToolAgent {
|
|
tool_name: String,
|
|
input: serde_json::Value,
|
|
}
|
|
|
|
impl ScriptedLangChainToolAgent {
|
|
fn new(tool_name: String, input: serde_json::Value) -> Self {
|
|
Self { tool_name, input }
|
|
}
|
|
}
|
|
|
|
#[async_trait]
|
|
impl BaseAgent for ScriptedLangChainToolAgent {
|
|
async fn plan(
|
|
&self,
|
|
intermediate_steps: &[AgentStep],
|
|
_inputs: &HashMap<String, String>,
|
|
) -> Result<AgentOutput, AgentError> {
|
|
if let Some(step) = intermediate_steps.last() {
|
|
return Ok(AgentOutput::Finish(AgentFinish::new(
|
|
step.observation.clone(),
|
|
String::new(),
|
|
)));
|
|
}
|
|
|
|
Ok(AgentOutput::Action(AgentAction {
|
|
tool: self.tool_name.clone(),
|
|
tool_input: ToolInput::Object(self.input.clone()),
|
|
log: "scripted-call-1".to_string(),
|
|
}))
|
|
}
|
|
|
|
fn get_allowed_tools(&self) -> Option<Vec<&str>> {
|
|
Some(vec![self.tool_name.as_str()])
|
|
}
|
|
}
|
|
|
|
/// Diagnostic tool that echoes its own name and the raw input it receives.
struct EchoLangChainTool {
    /// Name this tool reports and echoes back in its output.
    name: String,
}

impl EchoLangChainTool {
    /// Creates an echo tool registered under `name`.
    fn new(name: String) -> Self {
        EchoLangChainTool { name }
    }
}
|
|
|
|
#[async_trait]
|
|
impl BaseTool for EchoLangChainTool {
|
|
fn name(&self) -> &str {
|
|
&self.name
|
|
}
|
|
|
|
fn description(&self) -> &str {
|
|
"用于验证 LangChain-Rust AgentExecutor 能执行最小工具调用。"
|
|
}
|
|
|
|
async fn run(&self, input: String) -> Result<String, ToolError> {
|
|
Ok(json!({
|
|
"ok": true,
|
|
"tool": self.name,
|
|
"input": input,
|
|
})
|
|
.to_string())
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the scripted agent end to end and checks the echoed JSON payload.
    #[tokio::test]
    async fn langchain_adapter_executes_minimal_tool_call() {
        let limits = FunctionAgentLimits {
            max_tool_calls: 2,
            timeout_ms: 1_000,
        };
        let adapter = LangChainRustAdapter::new(limits).expect("limits should be valid");

        let tool_name = "retrieve_puzzle_template_catalog";
        let output = adapter
            .execute_minimal_tool_call(tool_name, json!({"query": "拼图"}))
            .await
            .expect("langchain executor should run tool");

        let parsed: serde_json::Value =
            serde_json::from_str(&output).expect("tool output should be json");
        assert_eq!(parsed["ok"], true);
        assert_eq!(parsed["tool"], tool_name);
        // The echo tool receives its input as a string, so the query text must appear in it.
        let echoed_input = parsed["input"].as_str().unwrap_or_default();
        assert!(echoed_input.contains("拼图"));
    }

    /// Builds a function-calling agent pointed at a local port-9 URL and checks
    /// its `Debug` output. The agent is only constructed, never invoked.
    #[test]
    fn function_calling_agent_uses_gpt5_config_without_calling_network() {
        let adapter = LangChainRustAdapter::new(FunctionAgentLimits::default())
            .expect("limits should be valid");
        let system_prompt = String::from("系统提示");
        let agent = adapter.build_function_calling_agent(
            "test-key",
            "http://127.0.0.1:9/v1",
            Vec::new(),
            Some(system_prompt),
        );

        let rendered = format!("{:?}", agent);
        for expected in ["FunctionCallingAgent", "系统提示"] {
            assert!(rendered.contains(expected));
        }
    }
}
|