This commit is contained in:
2026-05-08 11:44:42 +08:00
parent b08127031c
commit abf1f1ebea
249 changed files with 39411 additions and 887 deletions

View File

@@ -0,0 +1,16 @@
[package]
name = "platform-agent"
edition.workspace = true
version.workspace = true
license.workspace = true
[dependencies]
async-trait = { workspace = true }
langchainrust = { workspace = true }
platform-llm = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
tokio = { workspace = true, features = ["time"] }
[dev-dependencies]
tokio = { workspace = true, features = ["macros", "rt", "time"] }

View File

@@ -0,0 +1,59 @@
use platform_llm::{
LlmClient, LlmMessage, LlmMessageContentPart, LlmMessageRole, LlmTextProtocol, LlmTextRequest,
LlmTextResponse,
};
use crate::error::PlatformAgentError;
pub const CREATIVE_AGENT_GPT5_MODEL: &str = "gpt-5";
#[derive(Clone)]
pub struct Gpt5ResponsesAgentClient {
llm_client: LlmClient,
}
impl Gpt5ResponsesAgentClient {
pub fn new(llm_client: LlmClient) -> Self {
Self { llm_client }
}
pub async fn request(
&self,
system_prompt: impl Into<String>,
user_text: impl Into<String>,
image_urls: Vec<String>,
) -> Result<LlmTextResponse, PlatformAgentError> {
let request = build_gpt5_multimodal_request(system_prompt, user_text, image_urls);
self.llm_client
.request_text(request)
.await
.map_err(Into::into)
}
}
pub fn build_gpt5_multimodal_request(
system_prompt: impl Into<String>,
user_text: impl Into<String>,
image_urls: Vec<String>,
) -> LlmTextRequest {
let mut user_parts = vec![LlmMessageContentPart::InputText {
text: user_text.into(),
}];
user_parts.extend(
image_urls
.into_iter()
.map(|image_url| LlmMessageContentPart::InputImage { image_url }),
);
LlmTextRequest {
model: Some(CREATIVE_AGENT_GPT5_MODEL.to_string()),
messages: vec![
LlmMessage::new(LlmMessageRole::System, system_prompt.into()),
LlmMessage::multimodal(LlmMessageRole::User, user_parts),
],
max_tokens: None,
request_timeout_ms: None,
enable_web_search: false,
protocol: LlmTextProtocol::Responses,
}
}

View File

@@ -0,0 +1,54 @@
use std::sync::Arc;
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum CreativeAgentCallbackKind {
Stage,
ToolStarted,
ToolCompleted,
ModelRequestStarted,
ModelRequestCompleted,
Error,
}
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct CreativeAgentCallbackEvent {
pub kind: CreativeAgentCallbackKind,
pub label: String,
pub detail: Option<String>,
}
type CreativeAgentCallbackFn = Arc<dyn Fn(CreativeAgentCallbackEvent) + Send + Sync>;
#[derive(Clone, Default)]
pub struct CreativeAgentCallbacks {
on_event: Option<CreativeAgentCallbackFn>,
}
impl CreativeAgentCallbacks {
pub fn new<F>(on_event: F) -> Self
where
F: Fn(CreativeAgentCallbackEvent) + Send + Sync + 'static,
{
Self {
on_event: Some(Arc::new(on_event)),
}
}
pub fn noop() -> Self {
Self::default()
}
pub fn emit(&self, event: CreativeAgentCallbackEvent) {
if let Some(on_event) = &self.on_event {
on_event(event);
}
}
pub fn stage(&self, label: impl Into<String>) {
self.emit(CreativeAgentCallbackEvent {
kind: CreativeAgentCallbackKind::Stage,
label: label.into(),
detail: None,
});
}
}

View File

@@ -0,0 +1,40 @@
use std::{error::Error, fmt};
#[derive(Debug, PartialEq, Eq)]
pub enum PlatformAgentError {
InvalidInput(String),
ToolNotFound(String),
ToolExecution(String),
ToolBudgetExceeded { max_tool_calls: usize },
Timeout { timeout_ms: u64 },
Llm(String),
LangChain(String),
OutputParse(String),
}
impl fmt::Display for PlatformAgentError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::InvalidInput(message)
| Self::ToolExecution(message)
| Self::Llm(message)
| Self::LangChain(message)
| Self::OutputParse(message) => write!(f, "{message}"),
Self::ToolNotFound(name) => write!(f, "Agent 工具未注册:{name}"),
Self::ToolBudgetExceeded { max_tool_calls } => {
write!(f, "Agent 工具调用次数超过限制:{max_tool_calls}")
}
Self::Timeout { timeout_ms } => {
write!(f, "Agent 执行超时:{timeout_ms}ms")
}
}
}
}
impl Error for PlatformAgentError {}
impl From<platform_llm::LlmError> for PlatformAgentError {
fn from(error: platform_llm::LlmError) -> Self {
Self::Llm(error.to_string())
}
}

View File

@@ -0,0 +1,49 @@
use std::{future::Future, time::Duration};
use tokio::time::timeout;
use crate::error::PlatformAgentError;
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct FunctionAgentLimits {
pub max_tool_calls: usize,
pub timeout_ms: u64,
}
impl Default for FunctionAgentLimits {
fn default() -> Self {
Self {
max_tool_calls: 8,
timeout_ms: 30_000,
}
}
}
impl FunctionAgentLimits {
pub fn validate(&self) -> Result<(), PlatformAgentError> {
if self.max_tool_calls == 0 {
return Err(PlatformAgentError::InvalidInput(
"Agent max_tool_calls 必须大于 0".to_string(),
));
}
if self.timeout_ms == 0 {
return Err(PlatformAgentError::InvalidInput(
"Agent timeout_ms 必须大于 0".to_string(),
));
}
Ok(())
}
pub async fn run_with_timeout<F, T>(&self, future: F) -> Result<T, PlatformAgentError>
where
F: Future<Output = Result<T, PlatformAgentError>>,
{
self.validate()?;
timeout(Duration::from_millis(self.timeout_ms), future)
.await
.map_err(|_| PlatformAgentError::Timeout {
timeout_ms: self.timeout_ms,
})?
}
}

View File

@@ -0,0 +1,177 @@
use std::{collections::HashMap, sync::Arc};
use async_trait::async_trait;
use langchainrust::{
AgentAction, AgentError, AgentExecutor, AgentFinish, AgentOutput, AgentStep, BaseAgent,
BaseTool, FunctionCallingAgent, OpenAIChat, OpenAIConfig, ToolError, ToolInput,
};
use serde_json::json;
use crate::{
apimart_gpt5_adapter::CREATIVE_AGENT_GPT5_MODEL, error::PlatformAgentError,
function_agent::FunctionAgentLimits,
};
pub struct LangChainRustAdapter {
limits: FunctionAgentLimits,
}
impl LangChainRustAdapter {
pub fn new(limits: FunctionAgentLimits) -> Result<Self, PlatformAgentError> {
limits.validate()?;
Ok(Self { limits })
}
pub fn limits(&self) -> &FunctionAgentLimits {
&self.limits
}
pub fn build_function_calling_agent(
&self,
api_key: impl Into<String>,
base_url: impl Into<String>,
tools: Vec<Arc<dyn BaseTool>>,
system_prompt: Option<String>,
) -> FunctionCallingAgent {
let config = OpenAIConfig::new(api_key)
.with_base_url(base_url)
.with_model(CREATIVE_AGENT_GPT5_MODEL);
let llm = OpenAIChat::new(config);
FunctionCallingAgent::new(llm, tools, system_prompt)
}
pub async fn execute_minimal_tool_call(
&self,
tool_name: impl Into<String>,
input: serde_json::Value,
) -> Result<String, PlatformAgentError> {
let tool_name = tool_name.into();
let tools: Vec<Arc<dyn BaseTool>> =
vec![Arc::new(EchoLangChainTool::new(tool_name.clone()))];
let agent = Arc::new(ScriptedLangChainToolAgent::new(tool_name, input));
let executor =
AgentExecutor::new(agent, tools).with_max_iterations(self.limits.max_tool_calls);
self.limits
.run_with_timeout(async move {
executor
.invoke("执行最小工具调用".to_string())
.await
.map_err(|error| PlatformAgentError::LangChain(error.to_string()))
})
.await
}
}
struct ScriptedLangChainToolAgent {
tool_name: String,
input: serde_json::Value,
}
impl ScriptedLangChainToolAgent {
fn new(tool_name: String, input: serde_json::Value) -> Self {
Self { tool_name, input }
}
}
#[async_trait]
impl BaseAgent for ScriptedLangChainToolAgent {
async fn plan(
&self,
intermediate_steps: &[AgentStep],
_inputs: &HashMap<String, String>,
) -> Result<AgentOutput, AgentError> {
if let Some(step) = intermediate_steps.last() {
return Ok(AgentOutput::Finish(AgentFinish::new(
step.observation.clone(),
String::new(),
)));
}
Ok(AgentOutput::Action(AgentAction {
tool: self.tool_name.clone(),
tool_input: ToolInput::Object(self.input.clone()),
log: "scripted-call-1".to_string(),
}))
}
fn get_allowed_tools(&self) -> Option<Vec<&str>> {
Some(vec![self.tool_name.as_str()])
}
}
struct EchoLangChainTool {
name: String,
}
impl EchoLangChainTool {
fn new(name: String) -> Self {
Self { name }
}
}
#[async_trait]
impl BaseTool for EchoLangChainTool {
fn name(&self) -> &str {
&self.name
}
fn description(&self) -> &str {
"用于验证 LangChain-Rust AgentExecutor 能执行最小工具调用。"
}
async fn run(&self, input: String) -> Result<String, ToolError> {
Ok(json!({
"ok": true,
"tool": self.name,
"input": input,
})
.to_string())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn langchain_adapter_executes_minimal_tool_call() {
let adapter = LangChainRustAdapter::new(FunctionAgentLimits {
max_tool_calls: 2,
timeout_ms: 1_000,
})
.expect("limits should be valid");
let output = adapter
.execute_minimal_tool_call("retrieve_puzzle_template_catalog", json!({"query": "拼图"}))
.await
.expect("langchain executor should run tool");
let parsed: serde_json::Value =
serde_json::from_str(output.as_str()).expect("tool output should be json");
assert_eq!(parsed["ok"], true);
assert_eq!(parsed["tool"], "retrieve_puzzle_template_catalog");
assert!(
parsed["input"]
.as_str()
.unwrap_or_default()
.contains("拼图")
);
}
#[test]
fn function_calling_agent_uses_gpt5_config_without_calling_network() {
let adapter = LangChainRustAdapter::new(FunctionAgentLimits::default())
.expect("limits should be valid");
let agent = adapter.build_function_calling_agent(
"test-key",
"http://127.0.0.1:9/v1",
Vec::new(),
Some("系统提示".to_string()),
);
let debug_text = format!("{agent:?}");
assert!(debug_text.contains("FunctionCallingAgent"));
assert!(debug_text.contains("系统提示"));
}
}

View File

@@ -0,0 +1,23 @@
pub mod apimart_gpt5_adapter;
pub mod callbacks;
pub mod error;
pub mod function_agent;
pub mod langchain_adapter;
pub mod output_parser;
pub mod puzzle_phase1_agent;
pub mod tool_registry;
pub use apimart_gpt5_adapter::{
CREATIVE_AGENT_GPT5_MODEL, Gpt5ResponsesAgentClient, build_gpt5_multimodal_request,
};
pub use callbacks::{
CreativeAgentCallbackEvent, CreativeAgentCallbackKind, CreativeAgentCallbacks,
};
pub use error::PlatformAgentError;
pub use function_agent::FunctionAgentLimits;
pub use langchain_adapter::LangChainRustAdapter;
pub use puzzle_phase1_agent::{
CreativeAgentExecutor, MockLangChainRustAgentExecutor, PuzzlePhase1AgentInput,
PuzzlePhase1AgentOutput,
};
pub use tool_registry::{CreativeAgentTool, CreativeAgentToolRegistry, ToolExecutionBudget};

View File

@@ -0,0 +1,12 @@
use serde::de::DeserializeOwned;
use crate::error::PlatformAgentError;
pub fn parse_json_output<T>(raw_text: &str) -> Result<T, PlatformAgentError>
where
T: DeserializeOwned,
{
serde_json::from_str(raw_text).map_err(|error| {
PlatformAgentError::OutputParse(format!("解析 Agent JSON 输出失败:{error}"))
})
}

View File

@@ -0,0 +1,112 @@
use async_trait::async_trait;
use crate::{
callbacks::CreativeAgentCallbacks, error::PlatformAgentError,
function_agent::FunctionAgentLimits,
};
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct PuzzlePhase1AgentInput {
pub session_id: String,
pub user_text: String,
pub image_urls: Vec<String>,
pub limits: FunctionAgentLimits,
}
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct PuzzlePhase1AgentOutput {
pub session_id: String,
pub model: String,
pub final_message: String,
pub tool_call_count: usize,
}
#[async_trait]
pub trait CreativeAgentExecutor: Send + Sync {
async fn run_puzzle_phase1(
&self,
input: PuzzlePhase1AgentInput,
callbacks: CreativeAgentCallbacks,
) -> Result<PuzzlePhase1AgentOutput, PlatformAgentError>;
}
#[derive(Clone, Debug, Default)]
pub struct MockLangChainRustAgentExecutor;
#[async_trait]
impl CreativeAgentExecutor for MockLangChainRustAgentExecutor {
async fn run_puzzle_phase1(
&self,
input: PuzzlePhase1AgentInput,
callbacks: CreativeAgentCallbacks,
) -> Result<PuzzlePhase1AgentOutput, PlatformAgentError> {
input.limits.validate()?;
if input.session_id.trim().is_empty() {
return Err(PlatformAgentError::InvalidInput(
"creative session_id 不能为空".to_string(),
));
}
if input.user_text.trim().is_empty() && input.image_urls.is_empty() {
return Err(PlatformAgentError::InvalidInput(
"创意 Agent 输入文本和图片不能同时为空".to_string(),
));
}
callbacks.stage("perceiving");
callbacks.stage("thinking");
callbacks.stage("remembering");
callbacks.stage("selecting_puzzle_template");
callbacks.stage("waiting_template_confirmation");
Ok(PuzzlePhase1AgentOutput {
session_id: input.session_id,
model: crate::apimart_gpt5_adapter::CREATIVE_AGENT_GPT5_MODEL.to_string(),
final_message: "已完成创意 Agent Phase 1 mock 执行,可进入模板确认。".to_string(),
tool_call_count: 0,
})
}
}
#[cfg(test)]
mod tests {
use std::sync::{Arc, Mutex};
use super::*;
#[tokio::test]
async fn mock_executor_keeps_gpt5_model_and_emits_stages() {
let events = Arc::new(Mutex::new(Vec::new()));
let events_clone = events.clone();
let callbacks = CreativeAgentCallbacks::new(move |event| {
events_clone.lock().expect("events lock").push(event.label);
});
let output = MockLangChainRustAgentExecutor
.run_puzzle_phase1(
PuzzlePhase1AgentInput {
session_id: "creative-session-test".to_string(),
user_text: "做一个家庭纪念拼图".to_string(),
image_urls: vec!["https://example.com/ref.png".to_string()],
limits: FunctionAgentLimits::default(),
},
callbacks,
)
.await
.expect("mock executor should succeed");
assert_eq!(
output.model,
crate::apimart_gpt5_adapter::CREATIVE_AGENT_GPT5_MODEL
);
assert_eq!(
events.lock().expect("events lock").as_slice(),
[
"perceiving",
"thinking",
"remembering",
"selecting_puzzle_template",
"waiting_template_confirmation",
]
);
}
}

View File

@@ -0,0 +1,116 @@
use std::{
collections::HashMap,
sync::{
Arc,
atomic::{AtomicUsize, Ordering},
},
};
use async_trait::async_trait;
use serde_json::Value;
use crate::{
callbacks::{CreativeAgentCallbackEvent, CreativeAgentCallbackKind, CreativeAgentCallbacks},
error::PlatformAgentError,
};
#[async_trait]
pub trait CreativeAgentTool: Send + Sync {
fn name(&self) -> &str;
fn description(&self) -> &str;
fn input_schema(&self) -> Option<Value> {
None
}
async fn call(&self, input: Value) -> Result<Value, PlatformAgentError>;
}
#[derive(Debug)]
pub struct ToolExecutionBudget {
max_tool_calls: usize,
used_tool_calls: AtomicUsize,
}
impl ToolExecutionBudget {
pub fn new(max_tool_calls: usize) -> Result<Self, PlatformAgentError> {
if max_tool_calls == 0 {
return Err(PlatformAgentError::InvalidInput(
"Agent max_tool_calls 必须大于 0".to_string(),
));
}
Ok(Self {
max_tool_calls,
used_tool_calls: AtomicUsize::new(0),
})
}
pub fn reserve_call(&self) -> Result<(), PlatformAgentError> {
let previous = self.used_tool_calls.fetch_add(1, Ordering::SeqCst);
if previous >= self.max_tool_calls {
return Err(PlatformAgentError::ToolBudgetExceeded {
max_tool_calls: self.max_tool_calls,
});
}
Ok(())
}
}
#[derive(Default)]
pub struct CreativeAgentToolRegistry {
tools: HashMap<String, Arc<dyn CreativeAgentTool>>,
}
impl CreativeAgentToolRegistry {
pub fn new() -> Self {
Self::default()
}
pub fn register<T>(&mut self, tool: T)
where
T: CreativeAgentTool + 'static,
{
self.tools.insert(tool.name().to_string(), Arc::new(tool));
}
pub fn len(&self) -> usize {
self.tools.len()
}
pub fn is_empty(&self) -> bool {
self.tools.is_empty()
}
pub async fn execute(
&self,
name: &str,
input: Value,
budget: &ToolExecutionBudget,
callbacks: &CreativeAgentCallbacks,
) -> Result<Value, PlatformAgentError> {
budget.reserve_call()?;
let tool = self
.tools
.get(name)
.ok_or_else(|| PlatformAgentError::ToolNotFound(name.to_string()))?;
callbacks.emit(CreativeAgentCallbackEvent {
kind: CreativeAgentCallbackKind::ToolStarted,
label: name.to_string(),
detail: None,
});
let result = tool.call(input).await;
callbacks.emit(CreativeAgentCallbackEvent {
kind: if result.is_ok() {
CreativeAgentCallbackKind::ToolCompleted
} else {
CreativeAgentCallbackKind::Error
},
label: name.to_string(),
detail: result.as_ref().err().map(ToString::to_string),
});
result
}
}