1
This commit is contained in:
177
server-rs/crates/platform-agent/src/langchain_adapter.rs
Normal file
177
server-rs/crates/platform-agent/src/langchain_adapter.rs
Normal file
@@ -0,0 +1,177 @@
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use langchainrust::{
|
||||
AgentAction, AgentError, AgentExecutor, AgentFinish, AgentOutput, AgentStep, BaseAgent,
|
||||
BaseTool, FunctionCallingAgent, OpenAIChat, OpenAIConfig, ToolError, ToolInput,
|
||||
};
|
||||
use serde_json::json;
|
||||
|
||||
use crate::{
|
||||
apimart_gpt5_adapter::CREATIVE_AGENT_GPT5_MODEL, error::PlatformAgentError,
|
||||
function_agent::FunctionAgentLimits,
|
||||
};
|
||||
|
||||
pub struct LangChainRustAdapter {
|
||||
limits: FunctionAgentLimits,
|
||||
}
|
||||
|
||||
impl LangChainRustAdapter {
|
||||
pub fn new(limits: FunctionAgentLimits) -> Result<Self, PlatformAgentError> {
|
||||
limits.validate()?;
|
||||
Ok(Self { limits })
|
||||
}
|
||||
|
||||
pub fn limits(&self) -> &FunctionAgentLimits {
|
||||
&self.limits
|
||||
}
|
||||
|
||||
pub fn build_function_calling_agent(
|
||||
&self,
|
||||
api_key: impl Into<String>,
|
||||
base_url: impl Into<String>,
|
||||
tools: Vec<Arc<dyn BaseTool>>,
|
||||
system_prompt: Option<String>,
|
||||
) -> FunctionCallingAgent {
|
||||
let config = OpenAIConfig::new(api_key)
|
||||
.with_base_url(base_url)
|
||||
.with_model(CREATIVE_AGENT_GPT5_MODEL);
|
||||
let llm = OpenAIChat::new(config);
|
||||
FunctionCallingAgent::new(llm, tools, system_prompt)
|
||||
}
|
||||
|
||||
pub async fn execute_minimal_tool_call(
|
||||
&self,
|
||||
tool_name: impl Into<String>,
|
||||
input: serde_json::Value,
|
||||
) -> Result<String, PlatformAgentError> {
|
||||
let tool_name = tool_name.into();
|
||||
let tools: Vec<Arc<dyn BaseTool>> =
|
||||
vec![Arc::new(EchoLangChainTool::new(tool_name.clone()))];
|
||||
let agent = Arc::new(ScriptedLangChainToolAgent::new(tool_name, input));
|
||||
let executor =
|
||||
AgentExecutor::new(agent, tools).with_max_iterations(self.limits.max_tool_calls);
|
||||
|
||||
self.limits
|
||||
.run_with_timeout(async move {
|
||||
executor
|
||||
.invoke("执行最小工具调用".to_string())
|
||||
.await
|
||||
.map_err(|error| PlatformAgentError::LangChain(error.to_string()))
|
||||
})
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
struct ScriptedLangChainToolAgent {
|
||||
tool_name: String,
|
||||
input: serde_json::Value,
|
||||
}
|
||||
|
||||
impl ScriptedLangChainToolAgent {
|
||||
fn new(tool_name: String, input: serde_json::Value) -> Self {
|
||||
Self { tool_name, input }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl BaseAgent for ScriptedLangChainToolAgent {
|
||||
async fn plan(
|
||||
&self,
|
||||
intermediate_steps: &[AgentStep],
|
||||
_inputs: &HashMap<String, String>,
|
||||
) -> Result<AgentOutput, AgentError> {
|
||||
if let Some(step) = intermediate_steps.last() {
|
||||
return Ok(AgentOutput::Finish(AgentFinish::new(
|
||||
step.observation.clone(),
|
||||
String::new(),
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(AgentOutput::Action(AgentAction {
|
||||
tool: self.tool_name.clone(),
|
||||
tool_input: ToolInput::Object(self.input.clone()),
|
||||
log: "scripted-call-1".to_string(),
|
||||
}))
|
||||
}
|
||||
|
||||
fn get_allowed_tools(&self) -> Option<Vec<&str>> {
|
||||
Some(vec![self.tool_name.as_str()])
|
||||
}
|
||||
}
|
||||
|
||||
struct EchoLangChainTool {
|
||||
name: String,
|
||||
}
|
||||
|
||||
impl EchoLangChainTool {
|
||||
fn new(name: String) -> Self {
|
||||
Self { name }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl BaseTool for EchoLangChainTool {
|
||||
fn name(&self) -> &str {
|
||||
&self.name
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"用于验证 LangChain-Rust AgentExecutor 能执行最小工具调用。"
|
||||
}
|
||||
|
||||
async fn run(&self, input: String) -> Result<String, ToolError> {
|
||||
Ok(json!({
|
||||
"ok": true,
|
||||
"tool": self.name,
|
||||
"input": input,
|
||||
})
|
||||
.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn langchain_adapter_executes_minimal_tool_call() {
|
||||
let adapter = LangChainRustAdapter::new(FunctionAgentLimits {
|
||||
max_tool_calls: 2,
|
||||
timeout_ms: 1_000,
|
||||
})
|
||||
.expect("limits should be valid");
|
||||
|
||||
let output = adapter
|
||||
.execute_minimal_tool_call("retrieve_puzzle_template_catalog", json!({"query": "拼图"}))
|
||||
.await
|
||||
.expect("langchain executor should run tool");
|
||||
|
||||
let parsed: serde_json::Value =
|
||||
serde_json::from_str(output.as_str()).expect("tool output should be json");
|
||||
assert_eq!(parsed["ok"], true);
|
||||
assert_eq!(parsed["tool"], "retrieve_puzzle_template_catalog");
|
||||
assert!(
|
||||
parsed["input"]
|
||||
.as_str()
|
||||
.unwrap_or_default()
|
||||
.contains("拼图")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn function_calling_agent_uses_gpt5_config_without_calling_network() {
|
||||
let adapter = LangChainRustAdapter::new(FunctionAgentLimits::default())
|
||||
.expect("limits should be valid");
|
||||
let agent = adapter.build_function_calling_agent(
|
||||
"test-key",
|
||||
"http://127.0.0.1:9/v1",
|
||||
Vec::new(),
|
||||
Some("系统提示".to_string()),
|
||||
);
|
||||
|
||||
let debug_text = format!("{agent:?}");
|
||||
assert!(debug_text.contains("FunctionCallingAgent"));
|
||||
assert!(debug_text.contains("系统提示"));
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user