feat: 重构项目结构并添加企业微信集成功能
- 移除旧的用户和物品相关模块及测试 - 添加企业微信路由、服务和认证功能 - 实现企业微信API集成包括获取access_token、用户信息等 - 添加统一响应模型和JWT认证工具 - 重构主应用配置为环境变量驱动 - 清理不必要的文档字符串和注释
This commit is contained in:
@@ -1,10 +1,3 @@
|
||||
"""
|
||||
数据库模型模块
|
||||
|
||||
该模块定义了 SQLAlchemy ORM 模型,用于与数据库进行交互。
|
||||
当前包含日志表的模型定义。
|
||||
"""
|
||||
|
||||
from sqlalchemy import Column, String, Text, DateTime, BIGINT
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from datetime import datetime
|
||||
|
||||
30
src/main.py
30
src/main.py
@@ -1,28 +1,23 @@
|
||||
"""
|
||||
FastAPI 应用主入口文件
|
||||
|
||||
该文件负责初始化 FastAPI 应用实例,配置中间件,
|
||||
注册路由以及定义根路径和健康检查端点。
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
# 将项目根目录添加到 Python 路径中,确保可以正确导入项目模块
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from src.routers import users, items
|
||||
|
||||
# 初始化 FastAPI 应用实例
|
||||
APP_NAME = os.getenv("APP_NAME", "information-sign")
|
||||
DEBUG = os.getenv("DEBUG", "True").lower() == "true"
|
||||
API_PREFIX = os.getenv("API_PREFIX", "/api/v1")
|
||||
APP_VERSION = os.getenv("APP_VERSION", "1.0.0")
|
||||
|
||||
app = FastAPI(
|
||||
title="规范FastApi 开发基础框架",
|
||||
description="规范的FastApi 开发基础框架",
|
||||
version="1.0.0"
|
||||
title=APP_NAME,
|
||||
description="信息灯板",
|
||||
debug=DEBUG,
|
||||
version=APP_VERSION
|
||||
)
|
||||
|
||||
# 配置 CORS 中间件,允许所有来源、凭证、方法和头部
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
@@ -33,10 +28,8 @@ app.add_middleware(
|
||||
)
|
||||
|
||||
# 注册路由模块
|
||||
# 用户相关路由,前缀为 /users,标签为 users
|
||||
app.include_router(users.router, prefix="/users", tags=["users"])
|
||||
# 物品相关路由,前缀为 /items,标签为 items
|
||||
app.include_router(items.router, prefix="/items", tags=["items"])
|
||||
from src.routers import wechat_router
|
||||
app.include_router(wechat_router, prefix=API_PREFIX+"/wechat", tags=["企业微信"])
|
||||
|
||||
# 根路径端点,返回欢迎信息
|
||||
@app.get("/")
|
||||
@@ -48,7 +41,6 @@ async def root():
|
||||
async def health_check():
|
||||
return {"status": "healthy"}
|
||||
|
||||
# 当直接运行此文件时启动应用服务器
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
"""
|
||||
物品数据模型模块
|
||||
|
||||
该模块定义了物品相关的数据模型,使用 Pydantic 进行数据验证和序列化。
|
||||
包括基础物品模型、创建物品模型和完整物品模型。
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
|
||||
class ItemBase(BaseModel):
|
||||
"""
|
||||
物品基础模型,定义了物品的基本信息字段
|
||||
"""
|
||||
name: str # 物品名称,必需字段
|
||||
description: Optional[str] = None # 物品描述,可选字段
|
||||
price: float # 物品价格,必需字段
|
||||
|
||||
class ItemCreate(ItemBase):
|
||||
"""
|
||||
物品创建模型,继承自 ItemBase
|
||||
当前与 ItemBase 相同,但保留独立的类以便未来扩展
|
||||
"""
|
||||
pass
|
||||
|
||||
class Item(ItemBase):
|
||||
"""
|
||||
完整物品模型,继承自 ItemBase,增加了数据库相关字段
|
||||
"""
|
||||
id: int # 物品唯一标识符
|
||||
created_at: datetime = None # 物品创建时间,可选字段
|
||||
|
||||
class Config:
|
||||
# 允许从 ORM 模型转换为 Pydantic 模型
|
||||
from_attributes = True
|
||||
115
src/models/response.py
Normal file
115
src/models/response.py
Normal file
@@ -0,0 +1,115 @@
|
||||
from typing import Any, Dict, Generic, Optional, TypeVar, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
# 定义泛型T,用于表示响应数据的类型
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class ResponseModel(BaseModel, Generic[T]):
|
||||
"""
|
||||
统一响应格式模型
|
||||
|
||||
Attributes:
|
||||
code (int): 响应状态码,与HTTP状态码一致
|
||||
message (str): 响应消息,描述操作结果
|
||||
data (Optional[T]): 响应数据,可以是任意类型的数据
|
||||
success (bool): 操作是否成功
|
||||
"""
|
||||
|
||||
code: int
|
||||
message: str
|
||||
data: Optional[T] = None
|
||||
success: bool
|
||||
|
||||
|
||||
class ListResponseModel(ResponseModel[T]):
|
||||
"""
|
||||
列表数据响应格式模型
|
||||
|
||||
Attributes:
|
||||
total (Optional[int]): 数据总数,用于分页
|
||||
page (Optional[int]): 当前页码
|
||||
size (Optional[int]): 每页大小
|
||||
"""
|
||||
|
||||
total: Optional[int] = None
|
||||
page: Optional[int] = None
|
||||
size: Optional[int] = None
|
||||
|
||||
|
||||
class TokenResponseModel(BaseModel):
|
||||
"""
|
||||
认证令牌响应格式模型
|
||||
|
||||
Attributes:
|
||||
access_token (str): 访问令牌
|
||||
token_type (str): 令牌类型
|
||||
"""
|
||||
|
||||
access_token: str
|
||||
token_type: str
|
||||
|
||||
|
||||
# 成功响应的快捷函数
|
||||
def success_response(
|
||||
data: Optional[T] = None, message: str = "操作成功", code: int = 200
|
||||
) -> ResponseModel[T]:
|
||||
"""
|
||||
成功响应构造函数
|
||||
|
||||
参数:
|
||||
data: 响应数据
|
||||
message: 响应消息
|
||||
code: 响应状态码
|
||||
|
||||
返回:
|
||||
ResponseModel: 成功响应对象
|
||||
"""
|
||||
return ResponseModel(code=code, message=message, data=data, success=True)
|
||||
|
||||
|
||||
# 错误响应的快捷函数
|
||||
def error_response(
|
||||
message: str = "操作失败", code: int = 400, data: Optional[T] = None
|
||||
) -> ResponseModel[T]:
|
||||
"""
|
||||
错误响应构造函数
|
||||
|
||||
参数:
|
||||
message: 错误消息
|
||||
code: 错误状态码
|
||||
data: 错应数据
|
||||
|
||||
返回:
|
||||
ResponseModel: 错误响应对象
|
||||
"""
|
||||
return ResponseModel(code=code, message=message, data=data, success=False)
|
||||
|
||||
|
||||
# 分页成功响应的快捷函数
|
||||
def paginated_response(
|
||||
data: T, total: int, page: int, size: int, message: str = "操作成功"
|
||||
) -> ListResponseModel[T]:
|
||||
"""
|
||||
分页响应构造函数
|
||||
|
||||
参数:
|
||||
data: 响应数据
|
||||
total: 数据总数
|
||||
page: 当前页码
|
||||
size: 每页大小
|
||||
message: 响应消息
|
||||
|
||||
返回:
|
||||
ListResponseModel: 分页成功响应对象
|
||||
"""
|
||||
return ListResponseModel(
|
||||
code=200,
|
||||
message=message,
|
||||
data=data,
|
||||
success=True,
|
||||
total=total,
|
||||
page=page,
|
||||
size=size,
|
||||
)
|
||||
@@ -1,35 +0,0 @@
|
||||
"""
|
||||
用户数据模型模块
|
||||
|
||||
该模块定义了用户相关的数据模型,使用 Pydantic 进行数据验证和序列化。
|
||||
包括基础用户模型、创建用户模型和完整用户模型。
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
|
||||
class UserBase(BaseModel):
|
||||
"""
|
||||
用户基础模型,定义了用户的基本信息字段
|
||||
"""
|
||||
email: str # 用户邮箱,必需字段
|
||||
first_name: str # 用户名字,必需字段
|
||||
last_name: str # 用户姓氏,必需字段
|
||||
|
||||
class UserCreate(UserBase):
|
||||
"""
|
||||
用户创建模型,继承自 UserBase,增加了密码字段
|
||||
"""
|
||||
password: str # 用户密码,必需字段
|
||||
|
||||
class User(UserBase):
|
||||
"""
|
||||
完整用户模型,继承自 UserBase,增加了数据库相关字段
|
||||
"""
|
||||
id: int # 用户唯一标识符
|
||||
created_at: datetime = None # 用户创建时间,可选字段
|
||||
|
||||
class Config:
|
||||
# 允许从 ORM 模型转换为 Pydantic 模型
|
||||
from_attributes = True
|
||||
@@ -0,0 +1,5 @@
|
||||
"""路由模块包"""
|
||||
|
||||
from src.routers.wechat import wechat_router
|
||||
|
||||
__all__ = ["chat_router", "customer_allot_router", "wechat_router"]
|
||||
@@ -1,123 +0,0 @@
|
||||
"""
|
||||
物品路由器模块
|
||||
|
||||
该模块定义了物品相关的 RESTful API 端点,
|
||||
包括创建、读取、更新和删除物品等功能。
|
||||
注意:当前实现使用内存存储,实际应用中应替换为数据库存储。
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
# 将项目根目录添加到 Python 路径中,确保可以正确导入项目模块
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from typing import List
|
||||
from src.models.item import Item, ItemCreate
|
||||
|
||||
# 创建 API 路由器实例
|
||||
router = APIRouter()
|
||||
|
||||
# 模拟数据库存储,实际应用中应使用真实数据库
|
||||
items_db = []
|
||||
|
||||
# 创建物品端点
|
||||
# 接收 ItemCreate 模型数据,返回创建的 Item 对象
|
||||
@router.post("/", response_model=Item, status_code=201)
|
||||
async def create_item(item: ItemCreate):
|
||||
"""
|
||||
创建新物品
|
||||
|
||||
参数:
|
||||
- item: ItemCreate 模型,包含物品创建所需信息
|
||||
|
||||
返回:
|
||||
- Item: 创建成功的物品对象
|
||||
"""
|
||||
# 创建新物品对象并添加到数据库
|
||||
new_item = Item(id=len(items_db) + 1, **item.dict())
|
||||
items_db.append(new_item)
|
||||
return new_item
|
||||
|
||||
# 根据物品ID获取物品信息端点
|
||||
@router.get("/{item_id}", response_model=Item)
|
||||
async def read_item(item_id: int):
|
||||
"""
|
||||
根据物品ID获取物品信息
|
||||
|
||||
参数:
|
||||
- item_id: 物品ID
|
||||
|
||||
返回:
|
||||
- Item: 找到的物品对象
|
||||
|
||||
异常:
|
||||
- HTTPException: 当物品不存在时返回 404 错误
|
||||
"""
|
||||
# 查找指定ID的物品
|
||||
for item in items_db:
|
||||
if item.id == item_id:
|
||||
return item
|
||||
raise HTTPException(status_code=404, detail="Item not found")
|
||||
|
||||
# 获取物品列表端点,支持分页
|
||||
@router.get("/", response_model=List[Item])
|
||||
async def read_items(skip: int = 0, limit: int = 100):
|
||||
"""
|
||||
获取物品列表,支持分页
|
||||
|
||||
参数:
|
||||
- skip: 跳过的记录数,默认为 0
|
||||
- limit: 返回的记录数,默认为 100
|
||||
|
||||
返回:
|
||||
- List[Item]: 物品对象列表
|
||||
"""
|
||||
return items_db[skip : skip + limit]
|
||||
|
||||
# 更新物品信息端点
|
||||
@router.put("/{item_id}", response_model=Item)
|
||||
async def update_item(item_id: int, item_update: ItemCreate):
|
||||
"""
|
||||
更新物品信息
|
||||
|
||||
参数:
|
||||
- item_id: 要更新的物品ID
|
||||
- item_update: ItemCreate 模型,包含更新后的物品信息
|
||||
|
||||
返回:
|
||||
- Item: 更新后的物品对象
|
||||
|
||||
异常:
|
||||
- HTTPException: 当物品不存在时返回 404 错误
|
||||
"""
|
||||
# 查找并更新指定ID的物品
|
||||
for index, item in enumerate(items_db):
|
||||
if item.id == item_id:
|
||||
updated_item = Item(id=item_id, **item_update.dict())
|
||||
items_db[index] = updated_item
|
||||
return updated_item
|
||||
raise HTTPException(status_code=404, detail="Item not found")
|
||||
|
||||
# 删除物品端点
|
||||
@router.delete("/{item_id}", status_code=204)
|
||||
async def delete_item(item_id: int):
|
||||
"""
|
||||
删除物品
|
||||
|
||||
参数:
|
||||
- item_id: 要删除的物品ID
|
||||
|
||||
返回:
|
||||
- 无内容,成功时返回 204 状态码
|
||||
|
||||
异常:
|
||||
- HTTPException: 当物品不存在时返回 404 错误
|
||||
"""
|
||||
# 查找并删除指定ID的物品
|
||||
for index, item in enumerate(items_db):
|
||||
if item.id == item_id:
|
||||
items_db.pop(index)
|
||||
return
|
||||
raise HTTPException(status_code=404, detail="Item not found")
|
||||
@@ -1,131 +0,0 @@
|
||||
"""
|
||||
用户路由器模块
|
||||
|
||||
该模块定义了用户相关的 RESTful API 端点,
|
||||
包括创建、读取、更新和删除用户等功能。
|
||||
注意:当前实现使用内存存储,实际应用中应替换为数据库存储。
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
# 将项目根目录添加到 Python 路径中,确保可以正确导入项目模块
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from typing import List
|
||||
from src.models.user import User, UserCreate
|
||||
|
||||
# 创建 API 路由器实例
|
||||
router = APIRouter()
|
||||
|
||||
# 模拟数据库存储,实际应用中应使用真实数据库
|
||||
users_db = []
|
||||
|
||||
# 创建用户端点
|
||||
# 接收 UserCreate 模型数据,返回创建的 User 对象
|
||||
@router.post("/", response_model=User, status_code=201)
|
||||
async def create_user(user: UserCreate):
|
||||
"""
|
||||
创建新用户
|
||||
|
||||
参数:
|
||||
- user: UserCreate 模型,包含用户创建所需信息
|
||||
|
||||
返回:
|
||||
- User: 创建成功的用户对象
|
||||
|
||||
异常:
|
||||
- HTTPException: 当邮箱已被注册时返回 400 错误
|
||||
"""
|
||||
# 检查邮箱是否已存在
|
||||
for existing_user in users_db:
|
||||
if existing_user.email == user.email:
|
||||
raise HTTPException(status_code=400, detail="Email already registered")
|
||||
|
||||
# 创建新用户对象并添加到数据库
|
||||
new_user = User(id=len(users_db) + 1, **user.dict())
|
||||
users_db.append(new_user)
|
||||
return new_user
|
||||
|
||||
# 根据用户ID获取用户信息端点
|
||||
@router.get("/{user_id}", response_model=User)
|
||||
async def read_user(user_id: int):
|
||||
"""
|
||||
根据用户ID获取用户信息
|
||||
|
||||
参数:
|
||||
- user_id: 用户ID
|
||||
|
||||
返回:
|
||||
- User: 找到的用户对象
|
||||
|
||||
异常:
|
||||
- HTTPException: 当用户不存在时返回 404 错误
|
||||
"""
|
||||
# 查找指定ID的用户
|
||||
for user in users_db:
|
||||
if user.id == user_id:
|
||||
return user
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
# 获取用户列表端点,支持分页
|
||||
@router.get("/", response_model=List[User])
|
||||
async def read_users(skip: int = 0, limit: int = 100):
|
||||
"""
|
||||
获取用户列表,支持分页
|
||||
|
||||
参数:
|
||||
- skip: 跳过的记录数,默认为 0
|
||||
- limit: 返回的记录数,默认为 100
|
||||
|
||||
返回:
|
||||
- List[User]: 用户对象列表
|
||||
"""
|
||||
return users_db[skip : skip + limit]
|
||||
|
||||
# 更新用户信息端点
|
||||
@router.put("/{user_id}", response_model=User)
|
||||
async def update_user(user_id: int, user_update: UserCreate):
|
||||
"""
|
||||
更新用户信息
|
||||
|
||||
参数:
|
||||
- user_id: 要更新的用户ID
|
||||
- user_update: UserCreate 模型,包含更新后的用户信息
|
||||
|
||||
返回:
|
||||
- User: 更新后的用户对象
|
||||
|
||||
异常:
|
||||
- HTTPException: 当用户不存在时返回 404 错误
|
||||
"""
|
||||
# 查找并更新指定ID的用户
|
||||
for index, user in enumerate(users_db):
|
||||
if user.id == user_id:
|
||||
updated_user = User(id=user_id, **user_update.dict())
|
||||
users_db[index] = updated_user
|
||||
return updated_user
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
# 删除用户端点
|
||||
@router.delete("/{user_id}", status_code=204)
|
||||
async def delete_user(user_id: int):
|
||||
"""
|
||||
删除用户
|
||||
|
||||
参数:
|
||||
- user_id: 要删除的用户ID
|
||||
|
||||
返回:
|
||||
- 无内容,成功时返回 204 状态码
|
||||
|
||||
异常:
|
||||
- HTTPException: 当用户不存在时返回 404 错误
|
||||
"""
|
||||
# 查找并删除指定ID的用户
|
||||
for index, user in enumerate(users_db):
|
||||
if user.id == user_id:
|
||||
users_db.pop(index)
|
||||
return
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
364
src/routers/wechat.py
Normal file
364
src/routers/wechat.py
Normal file
@@ -0,0 +1,364 @@
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from typing import Optional, Dict
|
||||
import os
|
||||
from src.services.wechat import (
|
||||
get_wechat_access_token,
|
||||
get_userid_by_mobile,
|
||||
send_textcard_message,
|
||||
authenticate_wechat_user,
|
||||
get_department_list,
|
||||
get_user_detail,
|
||||
get_customer_list,
|
||||
get_external_contact_detail
|
||||
)
|
||||
from src.models.response import success_response, error_response
|
||||
from src.utils.auth import create_access_token
|
||||
|
||||
# 创建路由实例
|
||||
wechat_router = APIRouter()
|
||||
|
||||
|
||||
@wechat_router.get("/access-token", summary="获取微信AccessToken")
|
||||
async def api_get_access_token():
|
||||
"""
|
||||
获取企业微信access_token
|
||||
需要在环境变量中配置CORPID和CORPSECRET
|
||||
"""
|
||||
try:
|
||||
access_token = get_wechat_access_token()
|
||||
return success_response(
|
||||
data={"access_token": access_token},
|
||||
message="获取access_token成功",
|
||||
code=200
|
||||
)
|
||||
except ValueError as e:
|
||||
return error_response(
|
||||
message=str(e),
|
||||
code=400
|
||||
)
|
||||
except RuntimeError as e:
|
||||
return error_response(
|
||||
message=str(e),
|
||||
code=500
|
||||
)
|
||||
except Exception as e:
|
||||
return error_response(
|
||||
message=f"获取access_token失败: {str(e)}",
|
||||
code=500
|
||||
)
|
||||
|
||||
|
||||
@wechat_router.get("/userid-by-mobile", summary="根据手机号获取企业微信用户ID")
|
||||
async def api_get_userid_by_mobile(mobile: str):
|
||||
"""
|
||||
根据手机号获取企业微信成员userid
|
||||
|
||||
- **mobile**: 成员手机号(5~32字节)
|
||||
"""
|
||||
if not mobile:
|
||||
return error_response(
|
||||
message="手机号不能为空",
|
||||
code=400
|
||||
)
|
||||
|
||||
try:
|
||||
userid = get_userid_by_mobile(mobile)
|
||||
return success_response(
|
||||
data={"userid": userid, "mobile": mobile},
|
||||
message="获取用户ID成功",
|
||||
code=200
|
||||
)
|
||||
except ValueError as e:
|
||||
return error_response(
|
||||
message=str(e),
|
||||
code=400
|
||||
)
|
||||
except RuntimeError as e:
|
||||
return error_response(
|
||||
message=str(e),
|
||||
code=500
|
||||
)
|
||||
except Exception as e:
|
||||
return error_response(
|
||||
message=f"获取用户ID失败: {str(e)}",
|
||||
code=500
|
||||
)
|
||||
|
||||
|
||||
@wechat_router.post("/send-textcard", summary="发送文本卡片消息")
|
||||
async def api_send_textcard_message(
|
||||
touser: str,
|
||||
agentid: int,
|
||||
title: str,
|
||||
description: str,
|
||||
url: str,
|
||||
btntxt: Optional[str] = "详情",
|
||||
toparty: Optional[str] = "",
|
||||
totag: Optional[str] = "",
|
||||
enable_id_trans: Optional[int] = 0,
|
||||
enable_duplicate_check: Optional[int] = 0,
|
||||
duplicate_check_interval: Optional[int] = 1800
|
||||
):
|
||||
"""
|
||||
发送企业微信文本卡片消息
|
||||
|
||||
- **touser**: 成员ID列表,最多1000个,用 '|' 分隔;特殊情况填 '@all' 表示全部成员
|
||||
- **agentid**: 企业应用ID
|
||||
- **title**: 标题,不超过128字符
|
||||
- **description**: 描述,不超过512字符,支持div class="gray/highlight/normal"
|
||||
- **url**: 点击跳转链接,需含http/https
|
||||
- **btntxt**: 按钮文字,不超过4字符,默认"详情"
|
||||
- **toparty**: 部门ID列表,最多100个,用 '|' 分隔
|
||||
- **totag**: 标签ID列表,最多100个,用 '|' 分隔
|
||||
- **enable_id_trans**: 是否开启ID转译,0否1是,默认0
|
||||
- **enable_duplicate_check**: 是否开启重复消息检查,0否1是,默认0
|
||||
- **duplicate_check_interval**: 重复检查时间间隔,秒,默认1800,最大14400
|
||||
"""
|
||||
try:
|
||||
result = send_textcard_message(
|
||||
touser=touser,
|
||||
agentid=agentid,
|
||||
title=title,
|
||||
description=description,
|
||||
url=url,
|
||||
btntxt=btntxt,
|
||||
toparty=toparty,
|
||||
totag=totag,
|
||||
enable_id_trans=enable_id_trans,
|
||||
enable_duplicate_check=enable_duplicate_check,
|
||||
duplicate_check_interval=duplicate_check_interval
|
||||
)
|
||||
return success_response(
|
||||
data=result,
|
||||
message="发送文本卡片消息成功",
|
||||
code=200
|
||||
)
|
||||
except ValueError as e:
|
||||
return error_response(
|
||||
message=str(e),
|
||||
code=400
|
||||
)
|
||||
except RuntimeError as e:
|
||||
return error_response(
|
||||
message=str(e),
|
||||
code=500
|
||||
)
|
||||
except Exception as e:
|
||||
return error_response(
|
||||
message=f"发送文本卡片消息失败: {str(e)}",
|
||||
code=500
|
||||
)
|
||||
|
||||
|
||||
@wechat_router.get("/auth", summary="企业微信用户授权")
|
||||
async def api_wechat_auth(code: str):
|
||||
"""
|
||||
企业微信用户授权接口
|
||||
通过授权code获取用户信息并返回内部JWT令牌
|
||||
|
||||
- **code**: 企业微信授权code
|
||||
"""
|
||||
if not code:
|
||||
return error_response(
|
||||
message="授权code不能为空",
|
||||
code=400
|
||||
)
|
||||
|
||||
try:
|
||||
# 1. 通过code获取用户信息
|
||||
user_info = await authenticate_wechat_user(code)
|
||||
|
||||
# 2. 获取用户详情(包含部门信息)
|
||||
user_detail = get_user_detail(user_info["userid"])
|
||||
|
||||
# 3. 获取管理部门ID列表
|
||||
manager_dept_ids = os.getenv("MANAGER_DEPARTMENT_IDS", "").split(",")
|
||||
manager_dept_ids = [int(dept_id.strip()) for dept_id in manager_dept_ids if dept_id.strip()]
|
||||
|
||||
# 4. 检查用户是否属于管理部门
|
||||
user_departments = user_detail.get("department", [])
|
||||
is_manager = any(dept in manager_dept_ids for dept in user_departments)
|
||||
|
||||
# 5. 设置用户角色
|
||||
user_info["role"] = "admin" if is_manager else "user"
|
||||
# 6. 生成JWT令牌(包含用户角色和姓名信息)
|
||||
access_token = create_access_token(
|
||||
subject=user_info["userid"],
|
||||
additional_claims={
|
||||
"role": user_info["role"],
|
||||
"name": user_detail.get("name", "")
|
||||
}
|
||||
)
|
||||
# 7. 返回令牌和用户信息
|
||||
return success_response(
|
||||
data={
|
||||
"access_token": access_token,
|
||||
"user": {
|
||||
"userid": user_info["userid"],
|
||||
"name": user_detail.get("name", ""),
|
||||
"role": user_info["role"]
|
||||
}
|
||||
},
|
||||
message="授权成功",
|
||||
code=200
|
||||
)
|
||||
except RuntimeError as e:
|
||||
return error_response(
|
||||
message=str(e),
|
||||
code=401
|
||||
)
|
||||
except Exception as e:
|
||||
return error_response(
|
||||
message=f"授权失败: {str(e)}",
|
||||
code=500
|
||||
)
|
||||
|
||||
|
||||
@wechat_router.get("/departments", summary="获取部门列表")
|
||||
async def api_get_department_list(id: Optional[int] = None):
|
||||
"""
|
||||
获取企业微信部门列表
|
||||
|
||||
- **id**: 部门id(可选)。获取指定部门及其下的子部门(递归),不填则获取全量组织架构
|
||||
"""
|
||||
try:
|
||||
departments = get_department_list(id=id)
|
||||
return success_response(
|
||||
data=departments,
|
||||
message="获取部门列表成功",
|
||||
code=200
|
||||
)
|
||||
except ValueError as e:
|
||||
return error_response(
|
||||
message=str(e),
|
||||
code=400
|
||||
)
|
||||
except RuntimeError as e:
|
||||
return error_response(
|
||||
message=str(e),
|
||||
code=500
|
||||
)
|
||||
except Exception as e:
|
||||
return error_response(
|
||||
message=f"获取部门列表失败: {str(e)}",
|
||||
code=500
|
||||
)
|
||||
|
||||
|
||||
@wechat_router.get("/user/department", summary="通过userid获取用户所在部门")
|
||||
async def api_get_user_department(userid: str):
|
||||
"""
|
||||
通过企业微信成员userid获取其所在的部门信息
|
||||
|
||||
- **userid**: 企业微信成员userid
|
||||
"""
|
||||
if not userid:
|
||||
return error_response(
|
||||
message="userid不能为空",
|
||||
code=400
|
||||
)
|
||||
|
||||
try:
|
||||
user_detail = get_user_detail(userid)
|
||||
# 提取用户的部门信息
|
||||
department_info = {
|
||||
"userid": userid,
|
||||
"departments": user_detail.get("department", []),
|
||||
"main_department": user_detail.get("main_department")
|
||||
}
|
||||
return success_response(
|
||||
data=department_info,
|
||||
message="获取用户部门信息成功",
|
||||
code=200
|
||||
)
|
||||
except ValueError as e:
|
||||
return error_response(
|
||||
message=str(e),
|
||||
code=400
|
||||
)
|
||||
except RuntimeError as e:
|
||||
return error_response(
|
||||
message=str(e),
|
||||
code=500
|
||||
)
|
||||
except Exception as e:
|
||||
return error_response(
|
||||
message=f"获取用户部门信息失败: {str(e)}",
|
||||
code=500
|
||||
)
|
||||
|
||||
|
||||
@wechat_router.get("/customer-list", summary="获取客户列表")
|
||||
async def api_get_customer_list(userid: str):
|
||||
"""
|
||||
获取指定成员添加的客户列表
|
||||
|
||||
- **userid**: 企业成员的userid
|
||||
"""
|
||||
if not userid:
|
||||
return error_response(
|
||||
message="userid不能为空",
|
||||
code=400
|
||||
)
|
||||
|
||||
try:
|
||||
customer_list = get_customer_list(userid)
|
||||
return success_response(
|
||||
data=customer_list,
|
||||
message="获取客户列表成功",
|
||||
code=200
|
||||
)
|
||||
except ValueError as e:
|
||||
return error_response(
|
||||
message=str(e),
|
||||
code=400
|
||||
)
|
||||
except RuntimeError as e:
|
||||
return error_response(
|
||||
message=str(e),
|
||||
code=500
|
||||
)
|
||||
except Exception as e:
|
||||
return error_response(
|
||||
message=f"获取客户列表失败: {str(e)}",
|
||||
code=500
|
||||
)
|
||||
|
||||
|
||||
@wechat_router.get("/customer-detail", summary="获取客户详情")
|
||||
async def api_get_customer_detail(external_userid: str, cursor: str = ""):
|
||||
"""
|
||||
根据外部联系人 external_userid 获取客户详情
|
||||
|
||||
- **external_userid**: 外部联系人 userid(非企业成员账号)
|
||||
- **cursor**: 分页游标,当跟进人超过500人时使用上次返回的 next_cursor
|
||||
"""
|
||||
if not external_userid:
|
||||
return error_response(
|
||||
message="external_userid不能为空",
|
||||
code=400
|
||||
)
|
||||
|
||||
try:
|
||||
customer_detail = get_external_contact_detail(external_userid, cursor)
|
||||
return success_response(
|
||||
data=customer_detail,
|
||||
message="获取客户详情成功",
|
||||
code=200
|
||||
)
|
||||
except ValueError as e:
|
||||
return error_response(
|
||||
message=str(e),
|
||||
code=400
|
||||
)
|
||||
except RuntimeError as e:
|
||||
return error_response(
|
||||
message=str(e),
|
||||
code=500
|
||||
)
|
||||
except Exception as e:
|
||||
return error_response(
|
||||
message=f"获取客户详情失败: {str(e)}",
|
||||
code=500
|
||||
)
|
||||
|
||||
362
src/services/wechat.py
Normal file
362
src/services/wechat.py
Normal file
@@ -0,0 +1,362 @@
|
||||
import os
|
||||
import requests
|
||||
import time
|
||||
from typing import Dict, Optional, Any
|
||||
# 企业微信 API 接口地址
|
||||
# 获取 access_token 接口
|
||||
ACCESS_URL = "http://146.56.202.222:12345/proxy/https://qyapi.weixin.qq.com/cgi-bin/gettoken"
|
||||
# 获取 userid 接口
|
||||
GET_USERID_URL = "http://146.56.202.222:12345/proxy/https://qyapi.weixin.qq.com/cgi-bin/user/getuserid"
|
||||
# 发送文本卡片消息接口
|
||||
SEND_TEXTCARD_URL = "http://146.56.202.222:12345/proxy/https://qyapi.weixin.qq.com/cgi-bin/message/send"
|
||||
|
||||
class WeChatTokenManager:
|
||||
"""
|
||||
企业微信 access_token 管理
|
||||
负责获取、缓存和刷新 access_token
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._access_token: Optional[str] = None
|
||||
self._expires_at: float = 0
|
||||
self._corp_id = os.getenv("CORP_ID")
|
||||
self._corp_secret = os.getenv("CORP_SECRET")
|
||||
|
||||
if not self._corp_id or not self._corp_secret:
|
||||
raise ValueError("环境变量 CORP_ID 和 CORP_SECRET 必须配置")
|
||||
|
||||
def get_access_token(self) -> str:
|
||||
"""
|
||||
获取有效的 access_token
|
||||
如果缓存的 token 未过期则直接返回,否则重新获取
|
||||
"""
|
||||
if self._is_token_valid():
|
||||
return self._access_token
|
||||
|
||||
# 重新获取
|
||||
self._refresh_access_token()
|
||||
return self._access_token
|
||||
|
||||
def _is_token_valid(self) -> bool:
|
||||
"""检查当前缓存的 access_token 是否仍然有效"""
|
||||
return (
|
||||
self._access_token is not None and
|
||||
time.time() < self._expires_at
|
||||
)
|
||||
|
||||
def _refresh_access_token(self) -> None:
|
||||
"""
|
||||
向企业微信服务器请求新的 access_token
|
||||
并更新本地缓存
|
||||
"""
|
||||
url = ACCESS_URL
|
||||
params = {
|
||||
"corpid": self._corp_id,
|
||||
"corpsecret": self._corp_secret
|
||||
}
|
||||
|
||||
try:
|
||||
resp = requests.get(url, params=params, timeout=10)
|
||||
resp.raise_for_status()
|
||||
data: Dict = resp.json()
|
||||
except requests.RequestException as e:
|
||||
raise RuntimeError(f"获取 access_token 网络请求失败: {e}")
|
||||
|
||||
if data.get("errcode") != 0:
|
||||
raise RuntimeError(f"获取 access_token 失败: {data.get('errmsg')}")
|
||||
|
||||
self._access_token = data["access_token"]
|
||||
# 提前 5 分钟过期,避免临界时间误差
|
||||
self._expires_at = time.time() + data["expires_in"] - 300
|
||||
|
||||
# 记录日志
|
||||
# logger.info("access_token 已更新,有效期至 %s", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(self._expires_at)))
|
||||
|
||||
|
||||
# 全局单例,方便业务层直接调用
|
||||
_wechat_token_manager = None
|
||||
|
||||
def _get_token_manager() -> WeChatTokenManager:
|
||||
"""
|
||||
获取或创建微信令牌管理器实例
|
||||
使用懒加载模式,避免在模块导入时就检查环境变量
|
||||
"""
|
||||
global _wechat_token_manager
|
||||
if _wechat_token_manager is None:
|
||||
_wechat_token_manager = WeChatTokenManager()
|
||||
return _wechat_token_manager
|
||||
|
||||
def get_wechat_access_token() -> str:
|
||||
"""
|
||||
业务层直接调用此函数即可获取当前有效的 access_token
|
||||
"""
|
||||
return _get_token_manager().get_access_token()
|
||||
|
||||
def get_userid_by_mobile(mobile: str) -> str:
|
||||
"""
|
||||
通过手机号获取企业微信成员 userid
|
||||
:param mobile: 成员手机号(5~32字节)
|
||||
:return: 成员 userid
|
||||
:raises: RuntimeError 当接口调用失败或返回错误时
|
||||
"""
|
||||
access_token = get_wechat_access_token()
|
||||
url = GET_USERID_URL
|
||||
params = {"access_token": access_token}
|
||||
payload = {"mobile": mobile}
|
||||
|
||||
try:
|
||||
resp = requests.post(url, params=params, json=payload, timeout=10)
|
||||
resp.raise_for_status()
|
||||
data: Dict = resp.json()
|
||||
except requests.RequestException as e:
|
||||
raise RuntimeError(f"通过手机号获取 userid 网络请求失败: {e}")
|
||||
|
||||
if data.get("errcode") != 0:
|
||||
raise RuntimeError(f"通过手机号获取 userid 失败: {data.get('errmsg')}")
|
||||
|
||||
return data["userid"]
|
||||
|
||||
|
||||
|
||||
def send_textcard_message(
|
||||
touser: str,
|
||||
title: str,
|
||||
description: str,
|
||||
url: str,
|
||||
agentid: int = None,
|
||||
btntxt: str = "详情",
|
||||
toparty: str = "",
|
||||
totag: str = "",
|
||||
enable_id_trans: int = 0,
|
||||
enable_duplicate_check: int = 0,
|
||||
duplicate_check_interval: int = 1800
|
||||
) -> Dict:
|
||||
"""
|
||||
发送企业微信文本卡片消息
|
||||
:param touser: 成员ID列表,最多1000个,用 '|' 分隔;特殊情况填 '@all' 表示全部成员
|
||||
:param title: 标题,不超过128字符
|
||||
:param description: 描述,不超过512字符,支持div class="gray/highlight/normal"
|
||||
:param url: 点击跳转链接,需含http/https
|
||||
:param agentid: 企业应用ID,默认为环境变量AGENT_ID
|
||||
:param btntxt: 按钮文字,不超过4字符,默认"详情"
|
||||
:param toparty: 部门ID列表,最多100个,用 '|' 分隔
|
||||
:param totag: 标签ID列表,最多100个,用 '|' 分隔
|
||||
:param enable_id_trans: 是否开启ID转译,0否1是,默认0
|
||||
:param enable_duplicate_check: 是否开启重复消息检查,0否1是,默认0
|
||||
:param duplicate_check_interval: 重复检查时间间隔,秒,默认1800,最大14400
|
||||
:return: 企业微信接口返回的JSON字典
|
||||
:raises: RuntimeError 当网络或接口返回错误时
|
||||
"""
|
||||
access_token = get_wechat_access_token()
|
||||
api_url = SEND_TEXTCARD_URL
|
||||
params = {"access_token": access_token}
|
||||
|
||||
# 如果未传入 agentid,则使用环境变量 AGENT_ID
|
||||
if agentid is None:
|
||||
agentid = int(os.getenv("AGENT_ID", 0))
|
||||
|
||||
payload = {
|
||||
"touser": touser,
|
||||
"toparty": toparty,
|
||||
"totag": totag,
|
||||
"msgtype": "textcard",
|
||||
"agentid": agentid,
|
||||
"textcard": {
|
||||
"title": title,
|
||||
"description": description,
|
||||
"url": url,
|
||||
"btntxt": btntxt
|
||||
},
|
||||
"enable_id_trans": enable_id_trans,
|
||||
"enable_duplicate_check": enable_duplicate_check,
|
||||
"duplicate_check_interval": duplicate_check_interval
|
||||
}
|
||||
|
||||
try:
|
||||
resp = requests.post(api_url, params=params, json=payload, timeout=10)
|
||||
resp.raise_for_status()
|
||||
data: Dict = resp.json()
|
||||
except requests.RequestException as e:
|
||||
raise RuntimeError(f"发送文本卡片消息网络请求失败: {e}")
|
||||
|
||||
if data.get("errcode") != 0:
|
||||
raise RuntimeError(f"发送文本卡片消息失败: {data.get('errmsg')}")
|
||||
|
||||
return data
|
||||
|
||||
# 获取访问用户身份接口
|
||||
GET_USER_INFO_URL = "http://146.56.202.222:12345/proxy/https://qyapi.weixin.qq.com/cgi-bin/auth/getuserinfo"
|
||||
|
||||
def get_userinfo_by_code(code: str) -> Dict:
|
||||
"""
|
||||
通过成员授权获取到的 code 获取用户登陆身份
|
||||
:param code: 成员授权获取到的 code,最大 512 字节,只能使用一次,5 分钟未被使用自动过期
|
||||
:return: 企业微信接口返回的 JSON 字典,包含 userid / external_userid / openid 等字段
|
||||
:raises: RuntimeError 当网络或接口返回错误时
|
||||
"""
|
||||
access_token = get_wechat_access_token()
|
||||
url = GET_USER_INFO_URL
|
||||
params = {
|
||||
"access_token": access_token,
|
||||
"code": code
|
||||
}
|
||||
|
||||
try:
|
||||
resp = requests.get(url, params=params, timeout=10)
|
||||
resp.raise_for_status()
|
||||
data: Dict = resp.json()
|
||||
print(f"企业微信通过 code 获取用户身份接口原始响应: {resp.text}") # 打印原始响应
|
||||
except requests.RequestException as e:
|
||||
raise RuntimeError(f"通过 code 获取用户身份网络请求失败: {e}")
|
||||
|
||||
if data.get("errcode") != 0:
|
||||
raise RuntimeError(f"通过 code 获取用户身份失败: {data.get('errmsg')}")
|
||||
|
||||
return data
|
||||
|
||||
|
||||
async def authenticate_wechat_user(code: str) -> Dict[str, Any]:
|
||||
"""
|
||||
企业微信用户认证流程
|
||||
1. 通过code获取用户信息(仅返回userid)
|
||||
2. 如果需要,可以在数据库中创建或更新用户记录
|
||||
3. 返回用户信息(用于生成token)
|
||||
|
||||
Args:
|
||||
code: 企业微信授权code
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 用户信息字典,仅包含userid
|
||||
"""
|
||||
# 1. 通过code获取企业微信用户信息(仅userid)
|
||||
user_info = get_userinfo_by_code(code)
|
||||
|
||||
return {
|
||||
"userid": user_info.get("userid"),
|
||||
}
|
||||
|
||||
# 获取部门列表接口
|
||||
GET_DEPARTMENT_LIST_URL = "http://146.56.202.222:12345/proxy/https://qyapi.weixin.qq.com/cgi-bin/department/simplelist"
|
||||
|
||||
def get_department_list(id: Optional[int] = None) -> Dict:
|
||||
"""
|
||||
获取企业微信部门列表
|
||||
:param id: 部门id(可选)。获取指定部门及其下的子部门(递归),不填则获取全量组织架构
|
||||
:return: 企业微信接口返回的JSON字典,包含部门列表(department_id字段)
|
||||
:raises: RuntimeError 当接口调用失败或返回错误时
|
||||
"""
|
||||
access_token = get_wechat_access_token()
|
||||
url = GET_DEPARTMENT_LIST_URL
|
||||
params = {"access_token": access_token}
|
||||
|
||||
# 如果提供了部门id,则添加到请求参数中
|
||||
if id is not None:
|
||||
params["id"] = id
|
||||
|
||||
try:
|
||||
resp = requests.get(url, params=params, timeout=10)
|
||||
resp.raise_for_status()
|
||||
print(f"企业微信部门列表接口原始响应: {resp.text}") # 打印原始响应
|
||||
data: Dict = resp.json()
|
||||
except requests.RequestException as e:
|
||||
raise RuntimeError(f"获取部门列表网络请求失败: {e}")
|
||||
|
||||
if data.get("errcode") != 0:
|
||||
raise RuntimeError(f"获取部门列表失败: {data.get('errmsg')}")
|
||||
|
||||
return data
|
||||
# 读取成员详情接口(仅返回应用可见字段)
|
||||
GET_USER_DETAIL_URL = "http://146.56.202.222:12345/proxy/https://qyapi.weixin.qq.com/cgi-bin/user/get"
|
||||
|
||||
def get_user_detail(userid: str) -> Dict:
|
||||
"""
|
||||
通过 userid 获取企业微信成员详情(仅返回应用可见字段)
|
||||
:param userid: 成员 UserID
|
||||
:return: 企业微信接口返回的 JSON 字典
|
||||
:raises: RuntimeError 当网络或接口返回错误时
|
||||
"""
|
||||
access_token = get_wechat_access_token()
|
||||
url = GET_USER_DETAIL_URL
|
||||
params = {
|
||||
"access_token": access_token,
|
||||
"userid": userid
|
||||
}
|
||||
|
||||
try:
|
||||
resp = requests.get(url, params=params, timeout=10)
|
||||
resp.raise_for_status()
|
||||
data: Dict = resp.json()
|
||||
except requests.RequestException as e:
|
||||
raise RuntimeError(f"获取成员详情网络请求失败: {e}")
|
||||
|
||||
if data.get("errcode") != 0:
|
||||
raise RuntimeError(f"获取成员详情失败: {data.get('errmsg')}")
|
||||
|
||||
# 调试打印:仅打印非敏感字段,避免泄露
|
||||
print(f"[DEBUG] get_user_detail 原始响应: {data}")
|
||||
|
||||
return data
|
||||
|
||||
|
||||
|
||||
|
||||
# 获取客户列表接口
|
||||
GET_CUSTOMER_LIST_URL = "http://146.56.202.222:12345/proxy/https://qyapi.weixin.qq.com/cgi-bin/externalcontact/list"
|
||||
|
||||
def get_customer_list(userid: str) -> Dict:
|
||||
"""
|
||||
获取指定成员添加的客户列表
|
||||
:param userid: 企业成员的userid
|
||||
:return: 企业微信接口返回的JSON字典,包含external_userid列表
|
||||
:raises: RuntimeError 当接口调用失败或返回错误时
|
||||
"""
|
||||
access_token = get_wechat_access_token()
|
||||
url = GET_CUSTOMER_LIST_URL
|
||||
params = {
|
||||
"access_token": access_token,
|
||||
"userid": userid
|
||||
}
|
||||
|
||||
try:
|
||||
resp = requests.get(url, params=params, timeout=10)
|
||||
resp.raise_for_status()
|
||||
data: Dict = resp.json()
|
||||
except requests.RequestException as e:
|
||||
raise RuntimeError(f"获取客户列表网络请求失败: {e}")
|
||||
|
||||
if data.get("errcode") != 0:
|
||||
raise RuntimeError(f"获取客户列表失败: {data.get('errmsg')}")
|
||||
|
||||
return data
|
||||
|
||||
# 获取外部联系人详情接口
|
||||
GET_EXTERNAL_CONTACT_DETAIL_URL = "http://146.56.202.222:12345/proxy/https://qyapi.weixin.qq.com/cgi-bin/externalcontact/get"
|
||||
|
||||
def get_external_contact_detail(external_userid: str, cursor: str = "") -> Dict:
|
||||
"""
|
||||
根据外部联系人 external_userid 获取客户详情
|
||||
:param external_userid: 外部联系人 userid(非企业成员账号)
|
||||
:param cursor: 分页游标,当跟进人超过500人时使用上次返回的 next_cursor
|
||||
:return: 企业微信接口返回的 JSON 字典,包含 external_contact、follow_user 及 next_cursor
|
||||
:raises: RuntimeError 当网络或接口返回错误时
|
||||
"""
|
||||
access_token = get_wechat_access_token()
|
||||
url = GET_EXTERNAL_CONTACT_DETAIL_URL
|
||||
params = {
|
||||
"access_token": access_token,
|
||||
"external_userid": external_userid
|
||||
}
|
||||
if cursor:
|
||||
params["cursor"] = cursor
|
||||
|
||||
try:
|
||||
resp = requests.get(url, params=params, timeout=10)
|
||||
resp.raise_for_status()
|
||||
data: Dict = resp.json()
|
||||
except requests.RequestException as e:
|
||||
raise RuntimeError(f"获取外部联系人详情网络请求失败: {e}")
|
||||
|
||||
if data.get("errcode") != 0:
|
||||
raise RuntimeError(f"获取外部联系人详情失败: {data.get('errmsg')}")
|
||||
|
||||
return data
|
||||
167
src/utils/auth.py
Normal file
167
src/utils/auth.py
Normal file
@@ -0,0 +1,167 @@
|
||||
import os
|
||||
import jwt
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional, Dict, Any
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# 加载环境变量
|
||||
load_dotenv()
|
||||
|
||||
# 获取JWT配置
|
||||
JWT_SECRET_KEY = os.getenv("JWT_SECRET_KEY", "secret-key")
|
||||
JWT_ALGORITHM = os.getenv("JWT_ALGORITHM", "HS256")
|
||||
JWT_EXPIRE_MINUTES = int(os.getenv("JWT_EXPIRE_MINUTES", "1440")) # 默认24小时
|
||||
|
||||
|
||||
class JWTTokenManager:
|
||||
"""
|
||||
JWT令牌管理器,用于生成和验证JWT令牌
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def create_token(data: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str:
|
||||
"""
|
||||
生成JWT令牌
|
||||
|
||||
Args:
|
||||
data: 要编码到令牌中的数据
|
||||
expires_delta: 令牌过期时间,如果不提供则使用默认值
|
||||
|
||||
Returns:
|
||||
str: 生成的JWT令牌
|
||||
"""
|
||||
to_encode = data.copy()
|
||||
|
||||
if expires_delta:
|
||||
expire = datetime.now(timezone.utc) + expires_delta
|
||||
else:
|
||||
expire = datetime.now(timezone.utc) + timedelta(minutes=JWT_EXPIRE_MINUTES)
|
||||
|
||||
to_encode.update({"exp": expire})
|
||||
encoded_jwt = jwt.encode(to_encode, JWT_SECRET_KEY, algorithm=JWT_ALGORITHM)
|
||||
|
||||
return encoded_jwt
|
||||
|
||||
@staticmethod
|
||||
def verify_token(token: str) -> Dict[str, Any]:
|
||||
"""
|
||||
验证JWT令牌
|
||||
|
||||
Args:
|
||||
token: 要验证的JWT令牌
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 解码后的令牌数据
|
||||
|
||||
Raises:
|
||||
jwt.PyJWTError: 当令牌无效或过期时抛出
|
||||
"""
|
||||
payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM])
|
||||
return payload
|
||||
|
||||
|
||||
# 创建全局实例
|
||||
token_manager = JWTTokenManager()
|
||||
|
||||
|
||||
def create_access_token(subject: str, additional_claims: Optional[Dict[str, Any]] = None) -> str:
|
||||
"""
|
||||
创建访问令牌
|
||||
|
||||
Args:
|
||||
subject: 令牌的主题(通常是用户ID)
|
||||
additional_claims: 要添加到令牌中的额外声明
|
||||
|
||||
Returns:
|
||||
str: 生成的访问令牌
|
||||
"""
|
||||
claims = {"sub": subject}
|
||||
if additional_claims:
|
||||
claims.update(additional_claims)
|
||||
|
||||
return token_manager.create_token(claims)
|
||||
|
||||
|
||||
def verify_access_token(token: str) -> Dict[str, Any]:
|
||||
"""
|
||||
验证访问令牌
|
||||
|
||||
Args:
|
||||
token: 要验证的访问令牌
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 解码后的令牌数据
|
||||
|
||||
Raises:
|
||||
jwt.PyJWTError: 当令牌无效或过期时抛出
|
||||
"""
|
||||
return token_manager.verify_token(token)
|
||||
|
||||
|
||||
def get_current_user_id(token: str) -> str:
|
||||
"""
|
||||
从令牌中获取当前用户ID
|
||||
|
||||
Args:
|
||||
token: 访问令牌
|
||||
|
||||
Returns:
|
||||
str: 用户ID
|
||||
|
||||
Raises:
|
||||
jwt.PyJWTError: 当令牌无效或过期时抛出
|
||||
KeyError: 当令牌中没有sub字段时抛出
|
||||
"""
|
||||
payload = verify_access_token(token)
|
||||
user_id = payload.get("sub")
|
||||
|
||||
if not user_id:
|
||||
raise KeyError("Token is missing 'sub' claim")
|
||||
|
||||
return user_id
|
||||
|
||||
|
||||
from fastapi import HTTPException, Security
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
|
||||
# 创建Bearer令牌安全方案
|
||||
bearer_scheme = HTTPBearer(auto_error=False)
|
||||
|
||||
|
||||
async def get_current_user(credentials: HTTPAuthorizationCredentials = Security(bearer_scheme)) -> Dict[str, Any]:
|
||||
"""
|
||||
FastAPI依赖函数,用于获取当前用户信息
|
||||
从Authorization头中提取Bearer令牌并验证
|
||||
|
||||
Args:
|
||||
credentials: HTTPAuthorizationCredentials对象,包含令牌信息
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 解码后的令牌数据
|
||||
|
||||
Raises:
|
||||
HTTPException: 当令牌无效、过期或缺失时抛出401错误
|
||||
"""
|
||||
if not credentials:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="未提供授权令牌",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
try:
|
||||
token = credentials.credentials
|
||||
payload = verify_access_token(token)
|
||||
return payload
|
||||
except jwt.PyJWTError as e:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="无效或过期的令牌",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="认证失败",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
@@ -1,15 +1,3 @@
|
||||
"""
|
||||
日志工具模块
|
||||
|
||||
该模块提供了完整的日志记录功能,包括:
|
||||
1. 控制台日志输出
|
||||
2. 数据库日志存储
|
||||
3. 异常信息捕获和记录
|
||||
4. 不同日志级别的记录函数
|
||||
|
||||
日志信息会被同时输出到控制台和存储到数据库中,便于问题排查和系统监控。
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import logging
|
||||
|
||||
Reference in New Issue
Block a user