// Adapter between the platform's function-agent layer and the `langchainrust` crate.
//
// `LangChainRustAdapter` exposes two entry points:
// - `build_function_calling_agent`: constructs an LLM-backed `FunctionCallingAgent`;
// - `execute_minimal_tool_call`: runs a scripted, LLM-free tool call through
//   `AgentExecutor` under the adapter's `FunctionAgentLimits`.
//
// NOTE(review): in this copy of the file the generic arguments of several types
// appear to have been stripped (e.g. `-> Result`, `impl Into`, `Vec>`,
// `Option>`, `&HashMap`), which leaves the code unbalanced and uncompilable.
// Restore them from version control rather than guessing them here.
use std::{collections::HashMap, sync::Arc};

use async_trait::async_trait;
use langchainrust::{
    AgentAction, AgentError, AgentExecutor, AgentFinish, AgentOutput, AgentStep, BaseAgent,
    BaseTool, FunctionCallingAgent, OpenAIChat, OpenAIConfig, ToolError, ToolInput,
};
use serde_json::json;

use crate::{
    apimart_gpt5_adapter::CREATIVE_AGENT_GPT5_MODEL,
    error::PlatformAgentError,
    function_agent::FunctionAgentLimits,
};

/// Entry point for building and running `langchainrust` agents on behalf of the platform.
///
/// Holds a [`FunctionAgentLimits`] that passed `validate()` in [`LangChainRustAdapter::new`].
/// `execute_minimal_tool_call` uses it to cap executor iterations and wraps the run in
/// `run_with_timeout`.
pub struct LangChainRustAdapter {
    limits: FunctionAgentLimits,
}

impl LangChainRustAdapter {
    /// Creates an adapter after validating `limits`.
    ///
    /// # Errors
    ///
    /// Propagates (via `?`) the error returned by `FunctionAgentLimits::validate`
    /// when the limits are rejected.
    pub fn new(limits: FunctionAgentLimits) -> Result {
        limits.validate()?;
        Ok(Self { limits })
    }

    /// Returns the limits this adapter was constructed with.
    pub fn limits(&self) -> &FunctionAgentLimits {
        &self.limits
    }

    /// Builds a [`FunctionCallingAgent`] backed by an [`OpenAIChat`] client.
    ///
    /// The client is configured with `api_key`, `base_url` and the fixed model
    /// `CREATIVE_AGENT_GPT5_MODEL`. `tools` and the optional `system_prompt` are
    /// passed to the agent unchanged. This method only constructs values; it does
    /// not run the agent.
    ///
    /// NOTE(review): `self.limits` is not applied to the returned agent. Whoever
    /// runs it is responsible for enforcing tool-call and timeout limits.
    pub fn build_function_calling_agent(
        &self,
        api_key: impl Into,
        base_url: impl Into,
        tools: Vec>,
        system_prompt: Option,
    ) -> FunctionCallingAgent {
        let config = OpenAIConfig::new(api_key)
            .with_base_url(base_url)
            .with_model(CREATIVE_AGENT_GPT5_MODEL);
        let llm = OpenAIChat::new(config);
        FunctionCallingAgent::new(llm, tools, system_prompt)
    }

    /// Runs a single scripted tool call through [`AgentExecutor`] as an integration
    /// smoke test that involves no LLM.
    ///
    /// An [`EchoLangChainTool`] named `tool_name` is registered and driven by a
    /// [`ScriptedLangChainToolAgent`]. That agent calls the tool once with `input`,
    /// then finishes with the tool's observation. The executor's iteration cap is
    /// `limits.max_tool_calls`, and the whole invocation runs inside
    /// `limits.run_with_timeout`.
    ///
    /// On success this returns the executor's final output. In this setup that is
    /// expected to be the echo tool's JSON string, which the test below parses.
    ///
    /// # Errors
    ///
    /// Executor failures are mapped to `PlatformAgentError::LangChain`, carrying the
    /// error's `to_string()`. Timeout reporting is delegated to `run_with_timeout`
    /// (presumably also a `PlatformAgentError` — confirm in `function_agent`).
    ///
    /// NOTE(review): the scripted agent needs one planning round to emit the tool
    /// call and a second to finish. A `max_tool_calls` below 2 may therefore stop
    /// the executor before the finish step, depending on how `AgentExecutor` counts
    /// iterations — confirm.
    pub async fn execute_minimal_tool_call(
        &self,
        tool_name: impl Into,
        input: serde_json::Value,
    ) -> Result {
        let tool_name = tool_name.into();
        let tools: Vec> = vec![Arc::new(EchoLangChainTool::new(tool_name.clone()))];
        let agent = Arc::new(ScriptedLangChainToolAgent::new(tool_name, input));
        let executor =
            AgentExecutor::new(agent, tools).with_max_iterations(self.limits.max_tool_calls);
        self.limits
            .run_with_timeout(async move {
                // Prompt text means "execute a minimal tool call". Whatever the
                // executor does with it, the scripted agent's `plan` ignores its
                // `_inputs` argument.
                executor
                    .invoke("执行最小工具调用".to_string())
                    .await
                    .map_err(|error| PlatformAgentError::LangChain(error.to_string()))
            })
            .await
    }
}

/// Deterministic [`BaseAgent`] that stands in for the LLM in
/// `execute_minimal_tool_call`. It requests exactly one call to `tool_name` with
/// `input`, then finishes.
struct ScriptedLangChainToolAgent {
    /// Name of the only tool this agent calls and allows.
    tool_name: String,
    /// JSON payload sent to the tool as `ToolInput::Object`.
    input: serde_json::Value,
}

impl ScriptedLangChainToolAgent {
    fn new(tool_name: String, input: serde_json::Value) -> Self {
        Self { tool_name, input }
    }
}

#[async_trait]
impl BaseAgent for ScriptedLangChainToolAgent {
    /// Emits the scripted tool call first, then finishes.
    ///
    /// - No steps yet: returns an `AgentAction` that calls `tool_name` with `input`
    ///   (log tag `"scripted-call-1"`).
    /// - At least one step: returns an `AgentFinish` built from the most recent
    ///   step's observation and an empty string. These are presumably the output
    ///   and the log respectively — confirm against `AgentFinish::new`.
    ///
    /// `_inputs` is ignored. This method never returns `Err`.
    async fn plan(
        &self,
        intermediate_steps: &[AgentStep],
        _inputs: &HashMap,
    ) -> Result {
        if let Some(step) = intermediate_steps.last() {
            return Ok(AgentOutput::Finish(AgentFinish::new(
                step.observation.clone(),
                String::new(),
            )));
        }
        Ok(AgentOutput::Action(AgentAction {
            tool: self.tool_name.clone(),
            tool_input: ToolInput::Object(self.input.clone()),
            log: "scripted-call-1".to_string(),
        }))
    }

    /// Restricts the agent to its single scripted tool.
    fn get_allowed_tools(&self) -> Option> {
        Some(vec![self.tool_name.as_str()])
    }
}

/// Side-effect-free [`BaseTool`] used by `execute_minimal_tool_call`. It echoes its
/// own name and the raw input back as a JSON string.
struct EchoLangChainTool {
    name: String,
}

impl EchoLangChainTool {
    fn new(name: String) -> Self {
        Self { name }
    }
}

#[async_trait]
impl BaseTool for EchoLangChainTool {
    fn name(&self) -> &str {
        &self.name
    }

    // Description text (Chinese): "Used to verify that the LangChain-Rust
    // AgentExecutor can execute a minimal tool call."
    fn description(&self) -> &str {
        "用于验证 LangChain-Rust AgentExecutor 能执行最小工具调用。"
    }

    /// Returns `{"ok": true, "tool": <name>, "input": <input>}` serialized with
    /// `to_string()`. This method never returns `Err`.
    ///
    /// `input` is whatever string the executor passes for the tool call. The test
    /// only checks that it contains the query text, so it is presumably a
    /// serialization of the `ToolInput` — confirm against `AgentExecutor`.
    async fn run(&self, input: String) -> Result {
        Ok(json!({
            "ok": true,
            "tool": self.name,
            "input": input,
        })
        .to_string())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the scripted echo call end-to-end and checks the JSON the tool produced.
    #[tokio::test]
    async fn langchain_adapter_executes_minimal_tool_call() {
        // Two iterations: presumably one for the tool call and one for the finish step.
        let adapter = LangChainRustAdapter::new(FunctionAgentLimits {
            max_tool_calls: 2,
            timeout_ms: 1_000,
        })
        .expect("limits should be valid");
        let output = adapter
            .execute_minimal_tool_call("retrieve_puzzle_template_catalog", json!({"query": "拼图"}))
            .await
            .expect("langchain executor should run tool");
        let parsed: serde_json::Value =
            serde_json::from_str(output.as_str()).expect("tool output should be json");
        assert_eq!(parsed["ok"], true);
        assert_eq!(parsed["tool"], "retrieve_puzzle_template_catalog");
        // "拼图" means "jigsaw puzzle"; the query text must survive into the raw tool input.
        assert!(
            parsed["input"]
                .as_str()
                .unwrap_or_default()
                .contains("拼图")
        );
    }

    /// Builds a function-calling agent against loopback port 9 and only inspects
    /// its `Debug` output.
    ///
    /// NOTE(review): despite the name, nothing here asserts that the configured
    /// model is `CREATIVE_AGENT_GPT5_MODEL`. Only the type name and the system
    /// prompt (Chinese for "system prompt") are checked.
    #[test]
    fn function_calling_agent_uses_gpt5_config_without_calling_network() {
        let adapter = LangChainRustAdapter::new(FunctionAgentLimits::default())
            .expect("limits should be valid");
        let agent = adapter.build_function_calling_agent(
            "test-key",
            "http://127.0.0.1:9/v1",
            Vec::new(),
            Some("系统提示".to_string()),
        );
        let debug_text = format!("{agent:?}");
        assert!(debug_text.contains("FunctionCallingAgent"));
        assert!(debug_text.contains("系统提示"));
    }
}