Refactor lifespan with ChainBuilder for initialization tasks
- Replace manual initialization sequence with ChainBuilder pattern - Add ChainBuilder class supporting both sync and async task chaining - Rename test_init() to data_base_init() for clarity - Fix string formatting in log messages (remove f-string where unnecessary) - Fix escape sequence in department schema documentation - Convert CheckinType to Enum for better type safety
This commit is contained in:
130
lifespan.py
130
lifespan.py
@@ -1,12 +1,103 @@
|
|||||||
|
import asyncio
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from uvicorn.server import logger
|
from uvicorn.server import logger
|
||||||
|
|
||||||
|
|
||||||
async def test_init():
|
class ChainBuilder:
|
||||||
from service.sync.department import sync_department, check_department_datebase
|
"""支持链式调用的建造者类,支持同步和异步方法"""
|
||||||
from service.sync.employee import sync_department_user, check_employee_datebase
|
|
||||||
|
def __init__(self):
|
||||||
|
self._tasks = []
|
||||||
|
|
||||||
|
def add(self, func, *args, **kwargs):
|
||||||
|
"""添加同步或异步任务到链中"""
|
||||||
|
self._tasks.append((func, args, kwargs))
|
||||||
|
return self
|
||||||
|
|
||||||
|
def adds(self, *funcs_or_tuples):
|
||||||
|
"""添加一个或多个同步或异步任务到链中
|
||||||
|
|
||||||
|
支持多种调用方式:
|
||||||
|
1. 单个函数: add(func, *args, **kwargs)
|
||||||
|
2. 多个函数: add((func1, args1, kwargs1), (func2, args2, kwargs2), ...)
|
||||||
|
3. 混合方式: add(func1, *args1, **kwargs1), (func2, args2, kwargs2), ...
|
||||||
|
"""
|
||||||
|
for item in funcs_or_tuples:
|
||||||
|
if isinstance(item, tuple) and len(item) == 3:
|
||||||
|
# 如果是三元组 (func, args, kwargs)
|
||||||
|
func, args, kwargs = item
|
||||||
|
self._tasks.append((func, args, kwargs))
|
||||||
|
elif callable(item):
|
||||||
|
# 如果是单个函数,需要检查后续参数
|
||||||
|
if (
|
||||||
|
len(funcs_or_tuples) >= 3
|
||||||
|
and isinstance(funcs_or_tuples[1], tuple)
|
||||||
|
and isinstance(funcs_or_tuples[2], dict)
|
||||||
|
):
|
||||||
|
# 如果是 add(func, args, kwargs) 格式
|
||||||
|
func = item
|
||||||
|
args = funcs_or_tuples[1] if len(funcs_or_tuples) > 1 else ()
|
||||||
|
kwargs = funcs_or_tuples[2] if len(funcs_or_tuples) > 2 else {}
|
||||||
|
self._tasks.append((func, args, kwargs))
|
||||||
|
break # 处理完这个函数后退出循环
|
||||||
|
else:
|
||||||
|
# 单个函数没有参数
|
||||||
|
self._tasks.append((item, (), {}))
|
||||||
|
else:
|
||||||
|
raise ValueError(f"不支持的参数类型: {type(item)}")
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def _execute_async(self):
|
||||||
|
"""异步执行所有任务"""
|
||||||
|
for func, args, kwargs in self._tasks:
|
||||||
|
if asyncio.iscoroutinefunction(func):
|
||||||
|
await func(*args, **kwargs)
|
||||||
|
else:
|
||||||
|
# 如果是同步函数,在事件循环中运行
|
||||||
|
func(*args, **kwargs)
|
||||||
|
|
||||||
|
def __call__(self):
|
||||||
|
"""同步调用接口"""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
# 检查是否有异步任务
|
||||||
|
has_async = any(asyncio.iscoroutinefunction(func) for func, _, _ in self._tasks)
|
||||||
|
|
||||||
|
if has_async:
|
||||||
|
# 如果有异步任务,创建并运行事件循环
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
if loop.is_running():
|
||||||
|
# 如果事件循环已经在运行,创建任务
|
||||||
|
asyncio.create_task(self._execute_async())
|
||||||
|
else:
|
||||||
|
# 否则运行事件循环
|
||||||
|
loop.run_until_complete(self._execute_async())
|
||||||
|
else:
|
||||||
|
# 如果都是同步任务,直接执行
|
||||||
|
for func, args, kwargs in self._tasks:
|
||||||
|
func(*args, **kwargs)
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
"""异步上下文管理器入口"""
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
"""异步上下文管理器出口"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def execute_async(self):
|
||||||
|
"""显式异步执行方法"""
|
||||||
|
await self._execute_async()
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
async def data_base_init():
|
||||||
|
from service.sync.department import check_department_datebase, sync_department
|
||||||
|
from service.sync.employee import check_employee_datebase, sync_department_user
|
||||||
|
|
||||||
if not check_department_datebase():
|
if not check_department_datebase():
|
||||||
logger.info("[数据库] 开始同步部门 📦")
|
logger.info("[数据库] 开始同步部门 📦")
|
||||||
@@ -35,34 +126,41 @@ def init_scheduler(app: FastAPI):
|
|||||||
|
|
||||||
|
|
||||||
def active_config():
|
def active_config():
|
||||||
logger.info(f"[激活配置] 加载配置 ⚙️")
|
logger.info("[激活配置] 加载配置 ⚙️")
|
||||||
from config import Settings # noqa
|
from config import Settings # noqa
|
||||||
|
|
||||||
|
|
||||||
def import_router(app: FastAPI):
|
def import_router(app: FastAPI):
|
||||||
logger.info(f"[导入路由] 开始导入路由 🛣️")
|
logger.info("[导入路由] 开始导入路由 🛣️")
|
||||||
from router import router
|
from router import router
|
||||||
|
|
||||||
app.include_router(router)
|
app.include_router(router)
|
||||||
logger.info(f"[导入路由] 路由导入完成 ✅")
|
logger.info("[导入路由] 路由导入完成 ✅")
|
||||||
|
|
||||||
|
|
||||||
async def import_mcp_server(app: FastAPI):
|
async def import_mcp_server(app: FastAPI):
|
||||||
logger.info(f"[导入MCP] 开始导入MCP 🛣️")
|
logger.info("[导入MCP] 开始导入MCP 🛣️")
|
||||||
from mcps import create_mcp_app
|
from mcps import create_mcp_app
|
||||||
|
|
||||||
app.mount("/app", await create_mcp_app())
|
app.mount("/app", await create_mcp_app())
|
||||||
logger.info(f"[导入MCP] MCP导入完成 ✅")
|
logger.info("[导入MCP] MCP导入完成 ✅")
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
logger.info(f"[生命周期] 应用启动 🚀")
|
logger.info("[生命周期] 应用启动 🚀")
|
||||||
active_config()
|
builder = ChainBuilder()
|
||||||
init_database()
|
|
||||||
import_router(app)
|
# 激活配置
|
||||||
init_scheduler(app)
|
builder.add(active_config)
|
||||||
await import_mcp_server(app)
|
# 初始化数据库
|
||||||
await test_init()
|
builder.add(init_database).add(data_base_init)
|
||||||
|
# 导入MCP
|
||||||
|
builder.add(import_mcp_server, app)
|
||||||
|
# 导入路由
|
||||||
|
builder.add(import_router, app)
|
||||||
|
# 初始化定时任务
|
||||||
|
builder.add(init_scheduler, app)
|
||||||
|
await builder.execute_async()
|
||||||
yield
|
yield
|
||||||
logger.info(f"[生命周期] 应用关闭 🔧✅")
|
logger.info("[生命周期] 应用关闭 🔧✅")
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
from service.wecom.schemas.base import BaseSchema
|
from service.wecom.schemas.base import BaseSchema
|
||||||
|
|
||||||
|
|
||||||
class CheckinType:
|
class CheckinType(Enum):
|
||||||
"""打卡类型"""
|
"""打卡类型"""
|
||||||
|
|
||||||
ON_OFF_DUTY = 1 # 上下班打卡
|
ON_OFF_DUTY = 1 # 上下班打卡
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ class CreateDepartmentParams(BaseSchema):
|
|||||||
"""
|
"""
|
||||||
创建部门
|
创建部门
|
||||||
|
|
||||||
@param name: 部门名称。长度限制为1~32个字节,字符不能包括\:?”<>
|
@param name: 部门名称。长度限制为1~32个字节,字符不能包括\\:?”<>
|
||||||
@param name_en: 英文名称
|
@param name_en: 英文名称
|
||||||
@param parentid: 父部门id。根部门id为1
|
@param parentid: 父部门id。根部门id为1
|
||||||
@param order: 在父部门中的次序值。order值小的排序靠前。
|
@param order: 在父部门中的次序值。order值小的排序靠前。
|
||||||
|
|||||||
Reference in New Issue
Block a user