Files
Genarrative/server-rs/crates/api-server/src/wechat_provider.rs
2026-04-21 19:17:31 +08:00

281 lines
9.7 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
use module_auth::{WechatAuthScene, WechatIdentityProfile};
use reqwest::Client;
use serde::Deserialize;
use tracing::warn;
use url::Url;
use crate::{config::AppConfig, http_error::AppError};
use axum::http::StatusCode;
/// WeChat login backend chosen by `build_wechat_provider`.
#[derive(Clone, Debug)]
pub enum WechatProvider {
/// WeChat login is switched off; both public methods answer with a 400 error.
Disabled,
/// Stand-in provider that fabricates an identity from configured mock values.
Mock(MockWechatProvider),
/// Provider that calls the configured WeChat HTTP endpoints.
Real(RealWechatProvider),
}
/// Stand-in provider that builds a WeChat identity from configured values only.
#[derive(Clone, Debug)]
pub struct MockWechatProvider {
/// Used as the provider uid when the callback carries no non-blank `mock_code`.
mock_user_id: String,
/// Copied into the resolved profile's `provider_union_id`.
mock_union_id: Option<String>,
/// Copied into the resolved profile's `display_name`.
mock_display_name: String,
/// Copied into the resolved profile's `avatar_url`.
mock_avatar_url: Option<String>,
}
/// Provider that performs the WeChat OAuth exchange against the configured endpoints.
///
/// `Debug` is implemented by hand (instead of derived) so that `app_secret`
/// never ends up in logs or panic messages when this value, or the
/// `WechatProvider` wrapping it, is debug-formatted.
#[derive(Clone)]
pub struct RealWechatProvider {
    /// HTTP client used for the access-token and user-info requests.
    client: Client,
    /// Sent as the `appid` query parameter.
    app_id: String,
    /// Sent as the `secret` query parameter of the access-token request.
    app_secret: String,
    /// Base URL of the authorization page used for the desktop scene.
    authorize_endpoint: String,
    /// URL that exchanges an authorization `code` for an access token.
    access_token_endpoint: String,
    /// URL that returns the user's profile for an access token / openid pair.
    user_info_endpoint: String,
}

impl std::fmt::Debug for RealWechatProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RealWechatProvider")
            .field("client", &self.client)
            .field("app_id", &self.app_id)
            // Never print the credential itself.
            .field("app_secret", &"<redacted>")
            .field("authorize_endpoint", &self.authorize_endpoint)
            .field("access_token_endpoint", &self.access_token_endpoint)
            .field("user_info_endpoint", &self.user_info_endpoint)
            .finish()
    }
}
/// Fields read from the JSON body returned by the access-token endpoint.
///
/// Every field is optional, so a body that lacks any of them still
/// deserializes; `errmsg` is reported to the caller when `access_token` is absent.
#[derive(Debug, Deserialize)]
struct WechatAccessTokenResponse {
/// Token passed on to the user-info request.
access_token: Option<String>,
/// Open id passed on to the user-info request.
openid: Option<String>,
/// Fallback union id, used when the user-info response carries none.
unionid: Option<String>,
/// Message included in the login error when no access token is present.
errmsg: Option<String>,
}
/// Fields read from the JSON body returned by the user-info endpoint.
///
/// Every field is optional, so a body that lacks any of them still
/// deserializes; `errmsg` is reported to the caller when `openid` is absent.
#[derive(Debug, Deserialize)]
struct WechatUserInfoResponse {
/// Becomes the profile's `provider_uid`; required for a successful login.
openid: Option<String>,
/// Preferred source of the profile's `provider_union_id`.
unionid: Option<String>,
/// Becomes the profile's `display_name`.
nickname: Option<String>,
/// Becomes the profile's `avatar_url`.
headimgurl: Option<String>,
/// Message included in the login error when no openid is present.
errmsg: Option<String>,
}
pub fn build_wechat_provider(config: &AppConfig) -> WechatProvider {
if !config.wechat_auth_enabled {
return WechatProvider::Disabled;
}
if config
.wechat_auth_provider
.trim()
.eq_ignore_ascii_case("mock")
{
return WechatProvider::Mock(MockWechatProvider {
mock_user_id: config.wechat_mock_user_id.clone(),
mock_union_id: config.wechat_mock_union_id.clone(),
mock_display_name: config.wechat_mock_display_name.clone(),
mock_avatar_url: config.wechat_mock_avatar_url.clone(),
});
}
let Some(app_id) = config.wechat_app_id.clone() else {
return WechatProvider::Disabled;
};
let Some(app_secret) = config.wechat_app_secret.clone() else {
return WechatProvider::Disabled;
};
WechatProvider::Real(RealWechatProvider {
client: Client::new(),
app_id,
app_secret,
authorize_endpoint: config.wechat_authorize_endpoint.clone(),
access_token_endpoint: config.wechat_access_token_endpoint.clone(),
user_info_endpoint: config.wechat_user_info_endpoint.clone(),
})
}
impl WechatProvider {
    /// Returns the URL the client should be redirected to in order to start login.
    ///
    /// * `Disabled` – fails with a 400 error.
    /// * `Mock` – returns `callback_url` with `mock_code` and `state` appended
    ///   to its query string.
    /// * `Real` – delegates to the real provider.
    pub fn build_authorization_url(
        &self,
        callback_url: &str,
        state: &str,
        scene: &WechatAuthScene,
    ) -> Result<String, AppError> {
        match self {
            Self::Real(real) => real.build_authorization_url(callback_url, state, scene),
            Self::Mock(_) => Self::mock_authorization_url(callback_url, state),
            Self::Disabled => Err(Self::disabled_error()),
        }
    }

    /// Resolves the identity profile for a login callback.
    ///
    /// `code` is only consulted by the real provider and `mock_code` only by
    /// the mock provider; `Disabled` fails with a 400 error.
    pub async fn resolve_callback_profile(
        &self,
        code: Option<&str>,
        mock_code: Option<&str>,
    ) -> Result<WechatIdentityProfile, AppError> {
        match self {
            Self::Real(real) => real.resolve_callback_profile(code).await,
            Self::Mock(mock) => Ok(mock.resolve_callback_profile(mock_code)),
            Self::Disabled => Err(Self::disabled_error()),
        }
    }

    /// Error returned by every operation while WeChat login is disabled.
    fn disabled_error() -> AppError {
        AppError::from_status(StatusCode::BAD_REQUEST).with_message("微信登录暂未启用")
    }

    /// Builds the mock redirect: the callback URL plus a fixed mock code and `state`.
    fn mock_authorization_url(callback_url: &str, state: &str) -> Result<String, AppError> {
        let mut redirect = match Url::parse(callback_url) {
            Ok(parsed) => parsed,
            Err(error) => {
                return Err(AppError::from_status(StatusCode::INTERNAL_SERVER_ERROR)
                    .with_message(format!("微信回调地址非法：{error}")));
            }
        };
        {
            // Scoped so the query serializer is dropped before the URL is read.
            let mut query = redirect.query_pairs_mut();
            query.append_pair("mock_code", "wx-mock-code");
            query.append_pair("state", state);
        }
        Ok(redirect.as_str().to_owned())
    }
}
impl MockWechatProvider {
    /// Builds the mock identity profile.
    ///
    /// The provider uid is the trimmed `mock_code` when it is present and
    /// non-empty after trimming; otherwise the configured `mock_user_id`.
    /// All remaining fields come straight from the configured mock values.
    fn resolve_callback_profile(&self, mock_code: Option<&str>) -> WechatIdentityProfile {
        let Self {
            mock_user_id,
            mock_union_id,
            mock_display_name,
            mock_avatar_url,
        } = self;

        let provider_uid = match mock_code.map(str::trim) {
            Some(trimmed) if !trimmed.is_empty() => trimmed.to_owned(),
            _ => mock_user_id.clone(),
        };

        WechatIdentityProfile {
            provider_uid,
            provider_union_id: mock_union_id.clone(),
            display_name: Some(mock_display_name.clone()),
            avatar_url: mock_avatar_url.clone(),
        }
    }
}
impl RealWechatProvider {
fn build_authorization_url(
&self,
callback_url: &str,
state: &str,
scene: &WechatAuthScene,
) -> Result<String, AppError> {
let mut url = Url::parse(match scene {
WechatAuthScene::Desktop => &self.authorize_endpoint,
WechatAuthScene::WechatInApp => "https://open.weixin.qq.com/connect/oauth2/authorize",
})
.map_err(|error| {
AppError::from_status(StatusCode::INTERNAL_SERVER_ERROR)
.with_message(format!("微信授权地址非法:{error}"))
})?;
url.query_pairs_mut()
.append_pair("appid", &self.app_id)
.append_pair("redirect_uri", callback_url)
.append_pair("response_type", "code")
.append_pair(
"scope",
match scene {
WechatAuthScene::Desktop => "snsapi_login",
WechatAuthScene::WechatInApp => "snsapi_userinfo",
},
)
.append_pair("state", state);
Ok(format!("{url}#wechat_redirect"))
}
async fn resolve_callback_profile(
&self,
code: Option<&str>,
) -> Result<WechatIdentityProfile, AppError> {
let code = code
.map(str::trim)
.filter(|value| !value.is_empty())
.ok_or_else(|| {
AppError::from_status(StatusCode::BAD_REQUEST).with_message("缺少微信授权 code")
})?;
let mut access_token_url = Url::parse(&self.access_token_endpoint).map_err(|error| {
AppError::from_status(StatusCode::INTERNAL_SERVER_ERROR)
.with_message(format!("微信 access_token 地址非法:{error}"))
})?;
access_token_url
.query_pairs_mut()
.append_pair("appid", &self.app_id)
.append_pair("secret", &self.app_secret)
.append_pair("code", code)
.append_pair("grant_type", "authorization_code");
let access_token_payload = self
.client
.get(access_token_url.as_str())
.send()
.await
.map_err(|error| {
warn!(error = %error, "微信 access_token 请求失败");
AppError::from_status(StatusCode::BAD_GATEWAY)
.with_message("微信登录失败access_token 请求失败")
})?
.json::<WechatAccessTokenResponse>()
.await
.map_err(|error| {
warn!(error = %error, "微信 access_token 响应解析失败");
AppError::from_status(StatusCode::BAD_GATEWAY)
.with_message("微信登录失败access_token 响应非法")
})?;
let access_token = access_token_payload
.access_token
.filter(|value| !value.trim().is_empty())
.ok_or_else(|| {
AppError::from_status(StatusCode::BAD_GATEWAY).with_message(format!(
"微信登录失败:{}",
access_token_payload
.errmsg
.unwrap_or_else(|| "缺少 access_token".to_string())
))
})?;
let openid = access_token_payload
.openid
.filter(|value| !value.trim().is_empty())
.ok_or_else(|| {
AppError::from_status(StatusCode::BAD_GATEWAY)
.with_message("微信登录失败:缺少 openid")
})?;
let mut user_info_url = Url::parse(&self.user_info_endpoint).map_err(|error| {
AppError::from_status(StatusCode::INTERNAL_SERVER_ERROR)
.with_message(format!("微信用户信息地址非法:{error}"))
})?;
user_info_url
.query_pairs_mut()
.append_pair("access_token", &access_token)
.append_pair("openid", &openid)
.append_pair("lang", "zh_CN");
let user_info_payload = self
.client
.get(user_info_url.as_str())
.send()
.await
.map_err(|error| {
warn!(error = %error, "微信用户信息请求失败");
AppError::from_status(StatusCode::BAD_GATEWAY)
.with_message("微信登录失败:用户信息请求失败")
})?
.json::<WechatUserInfoResponse>()
.await
.map_err(|error| {
warn!(error = %error, "微信用户信息响应解析失败");
AppError::from_status(StatusCode::BAD_GATEWAY)
.with_message("微信登录失败:用户信息响应非法")
})?;
let provider_uid = user_info_payload
.openid
.filter(|value| !value.trim().is_empty())
.ok_or_else(|| {
AppError::from_status(StatusCode::BAD_GATEWAY).with_message(format!(
"微信登录失败:{}",
user_info_payload
.errmsg
.unwrap_or_else(|| "缺少 openid".to_string())
))
})?;
Ok(WechatIdentityProfile {
provider_uid,
provider_union_id: user_info_payload.unionid.or(access_token_payload.unionid),
display_name: user_info_payload.nickname,
avatar_url: user_info_payload.headimgurl,
})
}
}