Files
Genarrative/server-rs/crates/platform-agent/src/langchain_adapter.rs
2026-05-08 11:44:42 +08:00

178 lines
5.1 KiB
Rust

use std::{collections::HashMap, sync::Arc};
use async_trait::async_trait;
use langchainrust::{
AgentAction, AgentError, AgentExecutor, AgentFinish, AgentOutput, AgentStep, BaseAgent,
BaseTool, FunctionCallingAgent, OpenAIChat, OpenAIConfig, ToolError, ToolInput,
};
use serde_json::json;
use crate::{
apimart_gpt5_adapter::CREATIVE_AGENT_GPT5_MODEL, error::PlatformAgentError,
function_agent::FunctionAgentLimits,
};
/// Adapter that drives LangChain-Rust agents under this crate's [`FunctionAgentLimits`].
///
/// The limits are validated once in [`LangChainRustAdapter::new`] and then applied to
/// every execution started through the adapter.
pub struct LangChainRustAdapter {
    limits: FunctionAgentLimits,
}

impl LangChainRustAdapter {
    /// Creates an adapter after validating `limits`.
    ///
    /// # Errors
    ///
    /// Propagates the [`PlatformAgentError`] returned by `FunctionAgentLimits::validate`
    /// when the limits are rejected.
    pub fn new(limits: FunctionAgentLimits) -> Result<Self, PlatformAgentError> {
        limits.validate()?;
        Ok(Self { limits })
    }

    /// Returns the validated limits this adapter enforces.
    pub fn limits(&self) -> &FunctionAgentLimits {
        &self.limits
    }

    /// Builds a [`FunctionCallingAgent`] backed by an OpenAI-compatible chat model.
    ///
    /// The model is always [`CREATIVE_AGENT_GPT5_MODEL`]; `api_key` and `base_url`
    /// configure the OpenAI-compatible endpoint. `tools` are exposed to the agent and
    /// `system_prompt`, if given, is handed to the agent unchanged. This method only
    /// assembles configuration (the accompanying test builds it against an unreachable
    /// address without issuing requests).
    pub fn build_function_calling_agent(
        &self,
        api_key: impl Into<String>,
        base_url: impl Into<String>,
        tools: Vec<Arc<dyn BaseTool>>,
        system_prompt: Option<String>,
    ) -> FunctionCallingAgent {
        let config = OpenAIConfig::new(api_key)
            .with_base_url(base_url)
            .with_model(CREATIVE_AGENT_GPT5_MODEL);
        let llm = OpenAIChat::new(config);
        FunctionCallingAgent::new(llm, tools, system_prompt)
    }

    /// Runs one scripted tool call through LangChain-Rust's [`AgentExecutor`].
    ///
    /// A scripted agent first requests `tool_name` with `input` (wrapped as
    /// `ToolInput::Object`). It then finishes with that tool's observation as the
    /// final output. The registered tool is an echo tool whose observation is the JSON
    /// string `{"ok": true, "tool": <name>, "input": <raw input string>}`. No LLM is
    /// involved; this is a smoke check of the executor wiring.
    ///
    /// # Errors
    ///
    /// - [`PlatformAgentError::LangChain`] if the executor reports an error.
    /// - Whatever error `FunctionAgentLimits::run_with_timeout` produces when the run
    ///   does not complete within the configured timeout.
    pub async fn execute_minimal_tool_call(
        &self,
        tool_name: impl Into<String>,
        input: serde_json::Value,
    ) -> Result<String, PlatformAgentError> {
        let tool_name = tool_name.into();
        let tools: Vec<Arc<dyn BaseTool>> =
            vec![Arc::new(EchoLangChainTool::new(tool_name.clone()))];
        let agent = Arc::new(ScriptedLangChainToolAgent::new(tool_name, input));
        // The step that returns the final answer is a planning step that comes *after*
        // the last tool call. `max_tool_calls` tool calls therefore need one extra
        // executor iteration to reach it; otherwise `max_tool_calls == 1` stops before
        // the scripted agent can finish.
        // NOTE(review): assumes langchainrust counts planning steps as iterations, as
        // LangChain-style executors do — confirm against `AgentExecutor`.
        let max_iterations = self.limits.max_tool_calls.saturating_add(1);
        let executor = AgentExecutor::new(agent, tools).with_max_iterations(max_iterations);
        self.limits
            .run_with_timeout(async move {
                executor
                    .invoke("执行最小工具调用".to_string())
                    .await
                    .map_err(|error| PlatformAgentError::LangChain(error.to_string()))
            })
            .await
    }
}
/// Deterministic agent used to exercise [`AgentExecutor`] without an LLM.
///
/// It asks for exactly one call of `tool_name` with `input`. On the next planning
/// round it finishes, returning that call's observation.
struct ScriptedLangChainToolAgent {
    /// Name of the single tool the script invokes.
    tool_name: String,
    /// JSON arguments sent to that tool.
    input: serde_json::Value,
}

impl ScriptedLangChainToolAgent {
    /// Creates a script that calls `tool_name` once with `input`.
    fn new(tool_name: String, input: serde_json::Value) -> Self {
        ScriptedLangChainToolAgent { input, tool_name }
    }
}
#[async_trait]
impl BaseAgent for ScriptedLangChainToolAgent {
    /// Decides the next step of the scripted run.
    ///
    /// With no intermediate steps yet, it requests the configured tool with the
    /// configured JSON input. Once any step exists, it finishes and uses the most
    /// recent step's observation as the final output. `_inputs` is ignored.
    async fn plan(
        &self,
        intermediate_steps: &[AgentStep],
        _inputs: &HashMap<String, String>,
    ) -> Result<AgentOutput, AgentError> {
        let next = match intermediate_steps.last() {
            Some(latest) => AgentOutput::Finish(AgentFinish::new(
                latest.observation.clone(),
                String::new(),
            )),
            None => AgentOutput::Action(AgentAction {
                tool: self.tool_name.clone(),
                tool_input: ToolInput::Object(self.input.clone()),
                log: "scripted-call-1".to_string(),
            }),
        };
        Ok(next)
    }

    /// Reports the scripted tool as the only allowed tool.
    fn get_allowed_tools(&self) -> Option<Vec<&str>> {
        let allowed = vec![self.tool_name.as_str()];
        Some(allowed)
    }
}
/// Test-double tool that echoes each invocation back as a JSON string.
///
/// It is registered under the caller-chosen name, so the scripted agent's requested
/// tool resolves to it.
struct EchoLangChainTool {
    /// Name the tool is registered under.
    name: String,
}

impl EchoLangChainTool {
    /// Creates an echo tool registered under `name`.
    fn new(name: String) -> Self {
        EchoLangChainTool { name }
    }
}
#[async_trait]
impl BaseTool for EchoLangChainTool {
    /// The name chosen at construction.
    fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Fixed description: "Used to verify that the LangChain-Rust AgentExecutor can
    /// perform a minimal tool call."
    fn description(&self) -> &str {
        "用于验证 LangChain-Rust AgentExecutor 能执行最小工具调用。"
    }

    /// Never fails. Returns `{"ok": true, "tool": <name>, "input": <input>}`
    /// serialized as a JSON string, where `input` is the raw string the executor passed in.
    async fn run(&self, input: String) -> Result<String, ToolError> {
        let echoed = json!({
            "ok": true,
            "tool": self.name,
            "input": input,
        });
        Ok(echoed.to_string())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tool name used by the scripted end-to-end check.
    const PUZZLE_TOOL: &str = "retrieve_puzzle_template_catalog";

    /// The executor should run the echo tool once and surface its JSON observation.
    #[tokio::test]
    async fn langchain_adapter_executes_minimal_tool_call() {
        let limits = FunctionAgentLimits {
            max_tool_calls: 2,
            timeout_ms: 1_000,
        };
        let adapter = LangChainRustAdapter::new(limits).expect("limits should be valid");

        let raw_output = adapter
            .execute_minimal_tool_call(PUZZLE_TOOL, json!({"query": "拼图"}))
            .await
            .expect("langchain executor should run tool");
        let parsed: serde_json::Value =
            serde_json::from_str(&raw_output).expect("tool output should be json");

        assert_eq!(parsed["ok"], true);
        assert_eq!(parsed["tool"], PUZZLE_TOOL);
        // The echo tool receives the tool input as a string; it must still carry the query.
        let echoed_input = parsed["input"].as_str().unwrap_or_default();
        assert!(echoed_input.contains("拼图"));
    }

    /// Building the agent only assembles configuration; the endpoint is never contacted.
    #[test]
    fn function_calling_agent_uses_gpt5_config_without_calling_network() {
        let adapter = LangChainRustAdapter::new(FunctionAgentLimits::default())
            .expect("limits should be valid");
        let system_prompt = "系统提示".to_string();

        let agent = adapter.build_function_calling_agent(
            "test-key",
            "http://127.0.0.1:9/v1",
            Vec::new(),
            Some(system_prompt),
        );

        let rendered = format!("{agent:?}");
        for needle in ["FunctionCallingAgent", "系统提示"] {
            assert!(rendered.contains(needle));
        }
    }
}