Files
Genarrative/server-rs/crates/api-server/src/main.rs

458 lines
13 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#![recursion_limit = "256"]
mod admin;
mod ai_generation_drafts;
mod ai_tasks;
mod api_response;
mod app;
mod asset_billing;
mod assets;
mod auth;
mod auth_me;
mod auth_payload;
mod auth_public_user;
mod auth_session;
mod auth_sessions;
mod backpressure;
mod bark_battle;
mod big_fish;
mod big_fish_agent_turn;
mod big_fish_draft_compiler;
mod character_animation_assets;
mod character_visual_assets;
mod config;
mod creation_agent_anchor_templates;
mod creation_agent_chat;
mod creation_agent_document_input;
mod creation_agent_llm_turn;
mod creation_entry_config;
mod creative_agent;
mod creative_agent_sse;
mod custom_world;
mod custom_world_agent_entities;
mod custom_world_agent_turn;
mod custom_world_ai;
mod custom_world_asset_prompts;
mod custom_world_foundation_draft;
mod custom_world_result_prompts;
mod custom_world_rpg_draft_prompts;
mod edutainment_baby_drawing;
mod edutainment_baby_object;
mod error_middleware;
mod external_api_audit;
pub(crate) mod generated_asset_sheets;
mod generated_image_assets;
mod health;
mod http_error;
mod hyper3d_generation;
mod jump_hop;
mod llm;
mod llm_model_routing;
mod login_options;
mod logout;
mod logout_all;
mod match3d;
mod modules;
mod openai_image_generation;
mod password_entry;
mod password_management;
mod phone_auth;
mod platform_errors;
mod process_metrics;
mod profile_identity;
mod prompt;
mod public_work;
mod puzzle;
mod puzzle_agent_turn;
mod puzzle_gallery_cache;
mod refresh_session;
mod registration_reward;
mod request_context;
mod response_headers;
mod runtime_browse_history;
mod runtime_chat;
mod runtime_chat_plain;
mod runtime_inventory;
mod runtime_profile;
mod runtime_save;
mod runtime_settings;
mod session_client;
mod square_hole;
mod square_hole_agent_turn;
mod state;
mod story_battles;
mod story_sessions;
mod telemetry;
mod tracking;
mod tracking_outbox;
mod vector_engine_audio_generation;
mod visual_novel;
mod volcengine_speech;
mod wechat_auth;
mod wechat_pay;
mod wechat_provider;
mod wooden_fish;
mod work_author;
mod work_play_tracking;
use shared_logging::{OtelConfig, init_tracing};
use socket2::{Domain, Protocol, Socket, Type};
use std::{
collections::HashSet,
env, fs, future, io,
net::{SocketAddr, TcpListener as StdTcpListener},
panic,
sync::Arc,
thread,
time::Duration,
};
use tokio::net::TcpListener;
use tokio::runtime::Builder as TokioRuntimeBuilder;
use tokio::time::timeout;
use tracing::{error, info, warn};
use crate::{
app::{build_router, build_spacetime_unavailable_router},
config::AppConfig,
state::{AppState, AppStateInitError},
tracking_outbox::TrackingOutbox,
};
/// Stack size in bytes (32 MiB) used for both the bootstrap thread and the
/// Tokio worker threads created in `main`.
const API_SERVER_STARTUP_STACK_SIZE_BYTES: usize = 32 * 1024 * 1024;
/// Maximum time `restore_app_state_for_startup` waits for the auth store
/// restore before reporting the dependency as unavailable.
const AUTH_STORE_STARTUP_RESTORE_TIMEOUT: Duration = Duration::from_secs(8);
/// Everything the shutdown path needs once the HTTP server stops serving.
///
/// Cloned once so that both the graceful-shutdown signal future and the final
/// cleanup step in `run_server` can own a copy.
#[derive(Clone)]
struct ShutdownContext {
    /// Application state; `None` when the server started with the
    /// dependency-unavailable router.
    app_state: Option<AppState>,
    /// Tracking outbox to flush on exit, if the application state provides one.
    tracking_outbox: Option<Arc<TrackingOutbox>>,
    /// Time budget for the shutdown flush; a zero duration skips the flush.
    outbox_flush_timeout: Duration,
}
/// Process entry point.
///
/// Runs the whole server on a dedicated bootstrap thread with an enlarged
/// stack, then waits for it to finish. A panic on the bootstrap thread is
/// re-raised on the main thread; an I/O error is returned to the caller.
fn main() -> Result<(), io::Error> {
    // During local debugging on Windows the Axum route tree and the startup
    // restore chain are heavy, so the bootstrap thread stack is enlarged
    // explicitly to keep debug builds from overflowing before listening starts.
    let bootstrap = || -> Result<(), io::Error> {
        load_local_env_files();
        let config = AppConfig::from_env();

        let mut builder = TokioRuntimeBuilder::new_multi_thread();
        builder
            .enable_all()
            .thread_name("api-server-worker")
            .thread_stack_size(API_SERVER_STARTUP_STACK_SIZE_BYTES);
        if let Some(count) = config.worker_threads {
            builder.worker_threads(count);
        }

        let runtime = builder.build()?;
        runtime.block_on(run_server(config))
    };

    let handle = thread::Builder::new()
        .name("api-server-bootstrap".to_string())
        .stack_size(API_SERVER_STARTUP_STACK_SIZE_BYTES)
        .spawn(bootstrap)?;

    match handle.join() {
        Ok(outcome) => outcome,
        // Propagate the bootstrap thread's panic with its original payload.
        Err(panic_payload) => panic::resume_unwind(panic_payload),
    }
}
/// Initializes tracing and metrics, binds the listener, restores application
/// state and serves HTTP until a shutdown signal arrives.
///
/// If state restoration reports `AppStateInitError::DependencyUnavailable`,
/// the server still starts, using the router built by
/// `build_spacetime_unavailable_router`. Any other initialization error
/// aborts startup.
///
/// # Errors
///
/// Returns an `io::Error` when tracing initialization, socket setup, state
/// initialization, or serving itself fails.
async fn run_server(config: AppConfig) -> Result<(), io::Error> {
    let otel = OtelConfig {
        enabled: config.otel_enabled,
    };
    init_tracing(&config.log_filter, otel)?;
    process_metrics::register_process_metrics();
    telemetry::register_http_runtime_metrics();

    // Read everything needed later now: `config` is consumed by the restore step.
    let bind_address = config.bind_socket_addr();
    let listen_backlog = config.listen_backlog;
    let worker_threads = config.worker_threads;
    let otel_enabled = config.otel_enabled;
    let outbox_flush_timeout = config.shutdown_outbox_flush_timeout;

    // Bind before restoring state so a socket failure surfaces first.
    let listener = build_tcp_listener(bind_address, listen_backlog)?;

    let restored = restore_app_state_for_startup(config).await;
    let (router, shutdown_context) = match restored {
        Ok(state) => {
            state.puzzle_gallery_cache().spawn_cleanup_task();
            let tracking_outbox = state.tracking_outbox();
            if let Some(outbox) = tracking_outbox.clone() {
                outbox.spawn_worker();
            }
            let router = build_router(state.clone());
            let context = ShutdownContext {
                app_state: Some(state),
                tracking_outbox,
                outbox_flush_timeout,
            };
            (router, context)
        }
        Err(AppStateInitError::DependencyUnavailable(message)) => {
            let router = build_spacetime_unavailable_router(message);
            let context = ShutdownContext {
                app_state: None,
                tracking_outbox: None,
                outbox_flush_timeout,
            };
            (router, context)
        }
        Err(error) => {
            let message = format!("初始化应用状态失败:{error}");
            return Err(io::Error::other(message));
        }
    };

    info!(
        %bind_address,
        listen_backlog,
        worker_threads = worker_threads.unwrap_or(0),
        otel_enabled,
        "api-server 已完成 tracing 初始化并开始监听"
    );

    let drain_trigger = shutdown_signal(shutdown_context.clone());
    let serve_outcome = axum::serve(listener, router)
        .with_graceful_shutdown(drain_trigger)
        .await;

    // Cleanup runs whether serving ended normally or with an error.
    finalize_shutdown(shutdown_context).await;
    serve_outcome
}
/// Resolves once a shutdown signal has been received.
///
/// Before resolving, marks the application state (when present) as not ready
/// and logs which signal triggered the shutdown. Used as the graceful-shutdown
/// future for the HTTP server.
async fn shutdown_signal(context: ShutdownContext) {
    let received = wait_for_shutdown_signal().await;

    if let Some(state) = &context.app_state {
        state.mark_not_ready();
    }

    info!(
        signal = received,
        "api-server 收到退出信号,已标记 readiness 不可用并开始排空 HTTP 请求"
    );
}
/// Waits for the first shutdown signal and returns its name.
///
/// On Unix this races Ctrl-C against SIGTERM; on other platforms only Ctrl-C
/// is observed.
async fn wait_for_shutdown_signal() -> &'static str {
    #[cfg(unix)]
    let signal_name = tokio::select! {
        name = wait_for_ctrl_c_signal() => name,
        name = wait_for_sigterm_signal() => name,
    };

    #[cfg(not(unix))]
    let signal_name = wait_for_ctrl_c_signal().await;

    signal_name
}
/// Resolves with `"sigint"` when Ctrl-C is received.
///
/// If the Ctrl-C listener cannot be installed, the failure is logged and this
/// future never resolves, so it cannot trigger a shutdown by mistake.
async fn wait_for_ctrl_c_signal() -> &'static str {
    match tokio::signal::ctrl_c().await {
        Ok(()) => "sigint",
        Err(error) => {
            error!(error = %error, "监听 SIGINT 失败,无法通过 Ctrl-C 触发优雅退出");
            // Park forever: a broken listener must not look like a signal.
            future::pending::<&'static str>().await
        }
    }
}
/// Resolves with `"sigterm"` once the SIGTERM stream yields (Unix only).
///
/// If the SIGTERM listener cannot be installed, the failure is logged and this
/// future never resolves.
#[cfg(unix)]
async fn wait_for_sigterm_signal() -> &'static str {
    use tokio::signal::unix::{SignalKind, signal};

    match signal(SignalKind::terminate()) {
        Ok(mut stream) => {
            stream.recv().await;
            "sigterm"
        }
        Err(error) => {
            error!(error = %error, "监听 SIGTERM 失败,无法通过 systemd terminate 触发优雅退出");
            // Park forever: a broken listener must not look like a signal.
            future::pending::<&'static str>().await
        }
    }
}
async fn finalize_shutdown(context: ShutdownContext) {
if let Some(state) = context.app_state.as_ref() {
state.mark_not_ready();
}
let Some(outbox) = context.tracking_outbox else {
return;
};
if context.outbox_flush_timeout.is_zero() {
warn!("api-server 退出时 tracking outbox flush timeout 为 0跳过主动 flush");
return;
}
let timeout_ms = context
.outbox_flush_timeout
.as_millis()
.min(u128::from(u64::MAX)) as u64;
info!(timeout_ms, "api-server 退出前封存并 flush tracking outbox");
match timeout(context.outbox_flush_timeout, outbox.flush_for_shutdown()).await {
Ok(Ok(())) => {
info!("api-server 退出前 tracking outbox flush 完成");
}
Ok(Err(error)) => {
warn!(
error = %error,
"api-server 退出前 tracking outbox flush 未完成,已保留本地文件等待下次启动重试"
);
}
Err(_) => {
warn!(
timeout_ms,
"api-server 退出前 tracking outbox flush 超时,已保留本地文件等待下次启动重试"
);
}
}
}
/// Creates a non-blocking TCP listener bound to `bind_address`.
///
/// The socket is built by hand so that address reuse and the listen backlog
/// can be set before it is handed to Tokio.
///
/// # Errors
///
/// Returns the underlying `io::Error` if any socket operation fails, or if
/// Tokio cannot adopt the listener.
fn build_tcp_listener(
    bind_address: SocketAddr,
    listen_backlog: i32,
) -> Result<TcpListener, io::Error> {
    let socket = Socket::new(
        Domain::for_address(bind_address),
        Type::STREAM,
        Some(Protocol::TCP),
    )?;

    socket.set_reuse_address(true)?;
    // Must be non-blocking before Tokio takes ownership of the listener.
    socket.set_nonblocking(true)?;
    socket.bind(&bind_address.into())?;
    socket.listen(listen_backlog)?;

    let std_listener: StdTcpListener = socket.into();
    TcpListener::from_std(std_listener)
}
/// Restores the application state, bounded by
/// `AUTH_STORE_STARTUP_RESTORE_TIMEOUT`.
///
/// # Errors
///
/// Returns whatever error the restore itself produces. If the restore does
/// not finish in time, the timeout is logged and reported as
/// `AppStateInitError::DependencyUnavailable`.
async fn restore_app_state_for_startup(
    config: AppConfig,
) -> Result<AppState, state::AppStateInitError> {
    let restore = AppState::try_restore_auth_store_from_spacetime(config);

    let Ok(outcome) = timeout(AUTH_STORE_STARTUP_RESTORE_TIMEOUT, restore).await else {
        error!(
            timeout_seconds = AUTH_STORE_STARTUP_RESTORE_TIMEOUT.as_secs(),
            "启动等待 SpacetimeDB 恢复认证快照超时api-server 将进入依赖不可用模式"
        );
        return Err(state::AppStateInitError::DependencyUnavailable(
            "SpacetimeDB 启动恢复认证快照超时".to_string(),
        ));
    };

    outcome
}
/// Loads `.env`, `.env.local` and `.env.secrets.local` from the working
/// directory into the process environment, in that order.
///
/// Variables already set to a non-blank value in the environment are captured
/// once up front and never overwritten. Because that snapshot is taken before
/// any file is read, a later file can override a value set by an earlier one.
fn load_local_env_files() {
    const ENV_FILES: [&str; 3] = [".env", ".env.local", ".env.secrets.local"];

    let protected_keys = protected_env_keys_from(env::vars());
    ENV_FILES
        .iter()
        .for_each(|path| load_env_file(path, &protected_keys));
}
/// Collects the names of variables whose value is not blank.
///
/// A variable that is empty or whitespace-only is left out of the returned
/// set, so an env file is still allowed to supply a value for it.
fn protected_env_keys_from(vars: impl IntoIterator<Item = (String, String)>) -> HashSet<String> {
    vars.into_iter()
        .filter(|(_, value)| !value.trim().is_empty())
        .map(|(key, _)| key)
        .collect()
}
/// Applies the `KEY=VALUE` entries of one env file to the process environment.
///
/// A file that cannot be read is silently ignored. Blank lines, `#` comments,
/// lines without `=`, invalid keys, and keys listed in `shell_env_keys` are
/// skipped. A UTF-8 BOM at the start of the file or of a key is dropped.
fn load_env_file(path: &str, shell_env_keys: &HashSet<String>) {
    let contents = match fs::read_to_string(path) {
        Ok(contents) => contents,
        Err(_) => return,
    };

    let entries = contents
        .trim_start_matches('\u{feff}')
        .lines()
        .map(str::trim)
        .filter(|line| !line.is_empty() && !line.starts_with('#'))
        .filter_map(|line| line.split_once('='));

    for (raw_key, raw_value) in entries {
        let key = raw_key.trim().trim_start_matches('\u{feff}');
        let loadable = is_valid_env_key(key) && !shell_env_keys.contains(key);
        if !loadable {
            continue;
        }
        // SAFETY: the process environment is only written during startup,
        // before the Tokio runtime is created, which avoids concurrent
        // reads and writes of the environment.
        unsafe {
            env::set_var(key, strip_env_value(raw_value));
        }
    }
}
/// Normalizes the value part of a `KEY=VALUE` env-file line.
///
/// Surrounding whitespace (including a trailing `\r` from CRLF files) is
/// removed first, then one pair of matching wrapping quotes (`"…"` or `'…'`)
/// is stripped. Whitespace inside the quotes is kept. A lone or mismatched
/// quote is left as is.
///
/// Trimming before quote detection matters for lines such as `KEY = "value"`:
/// without it the value would keep its leading space and its quotes.
fn strip_env_value(raw_value: &str) -> String {
    let value = raw_value.trim();

    for quote in ['"', '\''] {
        let unquoted = value
            .strip_prefix(quote)
            .and_then(|rest| rest.strip_suffix(quote));
        if let Some(inner) = unquoted {
            return inner.to_string();
        }
    }

    value.to_string()
}
/// Returns `true` when `key` is a valid variable name: non-empty, starting
/// with `_` or an ASCII letter, and otherwise made up only of `_`, ASCII
/// letters and ASCII digits.
fn is_valid_env_key(key: &str) -> bool {
    let mut rest = key.chars();
    let Some(head) = rest.next() else {
        return false;
    };

    let head_ok = head == '_' || head.is_ascii_alphabetic();
    head_ok && rest.all(|ch| ch == '_' || ch.is_ascii_alphanumeric())
}
/// Unit tests for the env-file helpers.
#[cfg(test)]
mod tests {
    use super::{is_valid_env_key, protected_env_keys_from, strip_env_value};

    /// Wrapping quotes and a trailing carriage return are removed.
    #[test]
    fn strip_env_value_removes_wrapping_quotes() {
        let cases = [
            ("\"true\"", "true"),
            ("'aliyun'", "aliyun"),
            ("plain\r", "plain"),
        ];
        for (raw, expected) in cases {
            assert_eq!(strip_env_value(raw), expected);
        }
    }

    /// Mirrors the key normalization used when loading an env file.
    #[test]
    fn load_env_key_can_strip_utf8_bom_prefix() {
        let raw_key = "\u{feff}SMS_AUTH_ENABLED";
        let normalized = raw_key.trim().trim_start_matches('\u{feff}');
        assert_eq!(normalized, "SMS_AUTH_ENABLED");
    }

    /// Only names made of `_`, ASCII letters and digits are accepted, and the
    /// first character may not be a digit.
    #[test]
    fn is_valid_env_key_accepts_dotenv_key_subset() {
        for accepted in ["SMS_AUTH_ENABLED", "_LOCAL_KEY_1"] {
            assert!(is_valid_env_key(accepted));
        }
        for rejected in ["1_BAD", "BAD-KEY"] {
            assert!(!is_valid_env_key(rejected));
        }
    }

    /// Empty or whitespace-only variables must not block env-file values.
    #[test]
    fn empty_shell_env_does_not_protect_dotenv_value() {
        let shell_vars = [
            ("ALIYUN_OSS_BUCKET", ""),
            ("ALIYUN_OSS_ENDPOINT", " "),
            ("ALIYUN_OSS_ACCESS_KEY_ID", "configured"),
        ]
        .map(|(key, value)| (key.to_string(), value.to_string()));

        let protected = protected_env_keys_from(shell_vars);

        assert!(!protected.contains("ALIYUN_OSS_BUCKET"));
        assert!(!protected.contains("ALIYUN_OSS_ENDPOINT"));
        assert!(protected.contains("ALIYUN_OSS_ACCESS_KEY_ID"));
    }
}