Files
Genarrative/server-rs/crates/api-server/src/main.rs

274 lines
7.8 KiB
Rust

#![recursion_limit = "256"]
mod admin;
mod ai_generation_drafts;
mod ai_tasks;
mod api_response;
mod app;
mod asset_billing;
mod assets;
mod auth;
mod auth_me;
mod auth_payload;
mod auth_public_user;
mod auth_session;
mod auth_sessions;
mod big_fish;
mod big_fish_agent_turn;
mod big_fish_draft_compiler;
mod character_animation_assets;
mod character_visual_assets;
mod config;
mod creation_agent_anchor_templates;
mod creation_agent_chat;
mod creation_agent_document_input;
mod creation_agent_llm_turn;
mod creation_entry_config;
mod creative_agent;
mod creative_agent_sse;
mod custom_world;
mod custom_world_agent_entities;
mod custom_world_agent_turn;
mod custom_world_ai;
mod custom_world_asset_prompts;
mod custom_world_foundation_draft;
mod custom_world_result_prompts;
mod custom_world_rpg_draft_prompts;
mod error_middleware;
mod health;
mod http_error;
mod hyper3d_generation;
mod llm;
mod llm_model_routing;
mod login_options;
mod logout;
mod logout_all;
mod match3d;
mod openai_image_generation;
mod password_entry;
mod password_management;
mod phone_auth;
mod platform_errors;
mod profile_identity;
mod prompt;
mod puzzle;
mod puzzle_agent_turn;
mod refresh_session;
mod registration_reward;
mod request_context;
mod response_headers;
mod runtime_browse_history;
mod runtime_chat;
mod runtime_chat_plain;
mod runtime_inventory;
mod runtime_profile;
mod runtime_save;
mod runtime_settings;
mod session_client;
mod square_hole;
mod square_hole_agent_turn;
mod state;
mod story_battles;
mod story_sessions;
mod tracking;
mod vector_engine_audio_generation;
mod visual_novel;
mod volcengine_speech;
mod wechat_auth;
mod wechat_pay;
mod wechat_provider;
mod work_author;
mod work_play_tracking;
use shared_logging::init_tracing;
use std::{collections::HashSet, env, fs, io, panic, thread, time::Duration};
use tokio::net::TcpListener;
use tokio::runtime::Builder as TokioRuntimeBuilder;
use tokio::time::timeout;
use tracing::{info, warn};
use crate::{app::build_router, config::AppConfig, state::AppState};
/// Stack size (32 MiB) used for the bootstrap thread and for every Tokio worker thread.
const API_SERVER_STARTUP_STACK_SIZE_BYTES: usize = 32 * 1024 * 1024;
/// Upper bound on how long startup waits for `AppState::try_restore_auth_store_from_spacetime`
/// before giving up and building a fresh `AppState` instead.
const AUTH_STORE_STARTUP_RESTORE_TIMEOUT: Duration = Duration::from_secs(8);

/// Process entry point.
///
/// Loads the local `.env*` files first. It then runs [`run_server`] on a multi-threaded Tokio
/// runtime, hosted by a dedicated bootstrap thread with an enlarged stack. A panic on that
/// thread is re-raised on the main thread unchanged.
///
/// # Errors
/// Returns an I/O error if the bootstrap thread or the Tokio runtime cannot be created.
/// Otherwise it returns whatever error [`run_server`] produces.
fn main() -> Result<(), io::Error> {
    // `env::set_var` is only sound while no other thread can read or write the process
    // environment. Loading the env files here, before the bootstrap thread and the Tokio
    // runtime (and its worker threads) exist, keeps that precondition true.
    load_local_env_files();

    // In local Windows debug builds the Axum router tree and the startup restore chain are
    // heavy. The stack is enlarged explicitly so debug builds do not overflow before the
    // server starts listening.
    let server_thread = thread::Builder::new()
        .name("api-server-bootstrap".to_string())
        .stack_size(API_SERVER_STARTUP_STACK_SIZE_BYTES)
        .spawn(|| {
            TokioRuntimeBuilder::new_multi_thread()
                .enable_all()
                .thread_name("api-server-worker")
                .thread_stack_size(API_SERVER_STARTUP_STACK_SIZE_BYTES)
                .build()?
                .block_on(run_server())
        })?;
    match server_thread.join() {
        Ok(result) => result,
        Err(payload) => panic::resume_unwind(payload),
    }
}

/// Initializes tracing, binds the listener, restores application state, and serves the router.
///
/// Expects the local env files to have been loaded already (see `main`). `AppConfig::from_env`
/// builds the configuration from the process environment.
///
/// # Errors
/// Returns an I/O error in these cases:
/// - tracing initialization fails;
/// - binding the listen address fails;
/// - application state cannot be initialized (wrapped with a descriptive message);
/// - `axum::serve` fails.
async fn run_server() -> Result<(), io::Error> {
    // Read the listen address through the config object so that environment-variable access
    // is not scattered across the entry point and the routing layer.
    let config = AppConfig::from_env();
    init_tracing(&config.log_filter)?;
    let bind_address = config.bind_socket_addr();
    let listener = TcpListener::bind(bind_address).await?;
    let state = restore_app_state_for_startup(config)
        .await
        .map_err(|error| std::io::Error::other(format!("初始化应用状态失败:{error}")))?;
    let router = build_router(state);
    info!(%bind_address, "api-server 已完成 tracing 初始化并开始监听");
    axum::serve(listener, router).await
}
/// Builds the startup `AppState`, preferring an auth store restored from the remote snapshot.
///
/// The restore attempt is bounded by `AUTH_STORE_STARTUP_RESTORE_TIMEOUT`:
/// - If it finishes in time, its result (success or error) is returned as-is.
/// - If the deadline elapses first, a warning is logged and a fresh state is built with
///   `AppState::new` from a copy of the same config.
///
/// # Errors
/// Propagates `state::AppStateInitError` from either the restore attempt or `AppState::new`.
async fn restore_app_state_for_startup(
    config: AppConfig,
) -> Result<AppState, state::AppStateInitError> {
    // The restore future takes ownership of `config`, so keep a copy for the fallback path.
    let config_for_fallback = config.clone();
    let restore_attempt = AppState::try_restore_auth_store_from_spacetime(config);
    timeout(AUTH_STORE_STARTUP_RESTORE_TIMEOUT, restore_attempt)
        .await
        .unwrap_or_else(|_elapsed| {
            warn!(
                timeout_seconds = AUTH_STORE_STARTUP_RESTORE_TIMEOUT.as_secs(),
                "启动恢复认证快照超时,跳过远端恢复并继续启动 api-server"
            );
            AppState::new(config_for_fallback)
        })
}
/// Loads `.env`, `.env.local` and `.env.secrets.local` from the current working directory,
/// in that order.
///
/// Override rules:
/// - A variable that already had a non-blank value before any file was read (for example, one
///   set by the outer shell) is never overwritten.
/// - Among the files themselves, a later file overrides an earlier one.
///
/// Writes the process environment via `load_env_file`, so it must run while the process is
/// still single-threaded.
fn load_local_env_files() {
    const LOCAL_ENV_FILES: [&str; 3] = [".env", ".env.local", ".env.secrets.local"];
    // Snapshot the protected keys once, so values loaded from one file do not become
    // protected against later files.
    let protected_keys = protected_env_keys_from(env::vars());
    LOCAL_ENV_FILES
        .iter()
        .for_each(|file_path| load_env_file(file_path, &protected_keys));
}
/// Returns the names of the variables in `vars` whose value is not blank.
///
/// A value that is empty or whitespace-only does not protect its key. That lets a local env
/// file fill in a variable the shell declared but left empty.
fn protected_env_keys_from(vars: impl IntoIterator<Item = (String, String)>) -> HashSet<String> {
    let mut protected = HashSet::new();
    for (name, value) in vars {
        let has_content = !value.trim().is_empty();
        if has_content {
            protected.insert(name);
        }
    }
    protected
}
/// Applies the `KEY=VALUE` assignments from the dotenv-style file at `path` to the process
/// environment.
///
/// Parsing rules:
/// - A leading UTF-8 BOM is ignored.
/// - Blank lines, `#` comment lines and lines without `=` are skipped.
/// - Keys are trimmed. Keys that are not valid identifiers, or that appear in
///   `shell_env_keys`, are skipped.
/// - Values are normalized with `strip_env_value`.
///
/// A file that cannot be read is silently ignored, since every local env file is optional.
fn load_env_file(path: &str, shell_env_keys: &HashSet<String>) {
    let contents = match fs::read_to_string(path) {
        Ok(text) => text,
        Err(_) => return,
    };
    let contents = contents.trim_start_matches('\u{feff}');
    for line in contents.split('\n').map(str::trim) {
        if line.is_empty() || line.starts_with('#') {
            continue;
        }
        let Some((raw_key, raw_value)) = line.split_once('=') else {
            continue;
        };
        let key = raw_key.trim().trim_start_matches('\u{feff}');
        let should_apply = is_valid_env_key(key) && !shell_env_keys.contains(key);
        if !should_apply {
            continue;
        }
        // SAFETY: `env::set_var` is only sound while no other thread may read or write the
        // process environment. Callers must invoke this before spawning any thread, in
        // particular before the Tokio runtime is built.
        unsafe {
            env::set_var(key, strip_env_value(raw_value));
        }
    }
}
/// Normalizes the raw right-hand side of a `KEY=VALUE` line.
///
/// Steps:
/// 1. Surrounding whitespace, including a stray `\r` from CRLF files, is trimmed.
/// 2. If the result is wrapped in a matching pair of `"` or `'` quotes, that one pair is
///    removed and whitespace inside the quotes is kept.
///
/// A lone quote character or mismatched quotes are returned unchanged.
fn strip_env_value(raw_value: &str) -> String {
    // Trim the whole value, not just a trailing `\r`. Otherwise `KEY = "value"` would keep its
    // leading space and the quote check below would never match.
    let value = raw_value.trim();
    ['"', '\'']
        .into_iter()
        .find_map(|quote| value.strip_prefix(quote)?.strip_suffix(quote))
        .unwrap_or(value)
        .to_string()
}
/// Returns `true` when `key` matches `[A-Za-z_][A-Za-z0-9_]*`, the subset of dotenv key
/// syntax accepted here.
///
/// Empty keys and any non-ASCII character are rejected.
fn is_valid_env_key(key: &str) -> bool {
    let Some(head) = key.chars().next() else {
        return false;
    };
    let head_ok = head == '_' || head.is_ascii_alphabetic();
    head_ok
        && key
            .chars()
            .skip(1)
            .all(|ch| ch == '_' || ch.is_ascii_alphanumeric())
}
/// Unit tests for the local env-file helpers.
#[cfg(test)]
mod tests {
    use super::{is_valid_env_key, protected_env_keys_from, strip_env_value};

    #[test]
    fn strip_env_value_removes_wrapping_quotes() {
        let cases = [
            ("\"true\"", "true"),
            ("'aliyun'", "aliyun"),
            ("plain\r", "plain"),
        ];
        for (raw, expected) in cases {
            assert_eq!(strip_env_value(raw), expected, "input: {raw:?}");
        }
    }

    #[test]
    fn load_env_key_can_strip_utf8_bom_prefix() {
        // Same key normalization that `load_env_file` applies to the left-hand side of `=`.
        let raw_key = "\u{feff}SMS_AUTH_ENABLED";
        let normalized = raw_key.trim().trim_start_matches('\u{feff}');
        assert_eq!(normalized, "SMS_AUTH_ENABLED");
    }

    #[test]
    fn is_valid_env_key_accepts_dotenv_key_subset() {
        let cases = [
            ("SMS_AUTH_ENABLED", true),
            ("_LOCAL_KEY_1", true),
            ("1_BAD", false),
            ("BAD-KEY", false),
        ];
        for (key, expected) in cases {
            assert_eq!(is_valid_env_key(key), expected, "key: {key:?}");
        }
    }

    #[test]
    fn empty_shell_env_does_not_protect_dotenv_value() {
        let shell_vars = [
            ("ALIYUN_OSS_BUCKET", ""),
            ("ALIYUN_OSS_ENDPOINT", " "),
            ("ALIYUN_OSS_ACCESS_KEY_ID", "configured"),
        ]
        .map(|(key, value)| (key.to_string(), value.to_string()));
        let protected = protected_env_keys_from(shell_vars);

        // Blank values leave the key open for dotenv files; real values protect it.
        let expectations = [
            ("ALIYUN_OSS_BUCKET", false),
            ("ALIYUN_OSS_ENDPOINT", false),
            ("ALIYUN_OSS_ACCESS_KEY_ID", true),
        ];
        for (key, is_protected) in expectations {
            assert_eq!(protected.contains(key), is_protected, "key: {key}");
        }
    }
}