- Replace manual initialization sequence with ChainBuilder pattern - Add ChainBuilder class supporting both sync and async task chaining - Rename test_init() to data_base_init() for clarity - Fix string formatting in log messages (remove f-string where unnecessary) - Fix escape sequence in department schema documentation - Convert CheckinType to Enum for better type safety
167 lines
5.5 KiB
Python
167 lines
5.5 KiB
Python
import asyncio
|
|
from contextlib import asynccontextmanager
|
|
|
|
from fastapi import FastAPI
|
|
from uvicorn.server import logger
|
|
|
|
|
|
class ChainBuilder:
|
|
"""支持链式调用的建造者类,支持同步和异步方法"""
|
|
|
|
def __init__(self):
|
|
self._tasks = []
|
|
|
|
def add(self, func, *args, **kwargs):
|
|
"""添加同步或异步任务到链中"""
|
|
self._tasks.append((func, args, kwargs))
|
|
return self
|
|
|
|
def adds(self, *funcs_or_tuples):
|
|
"""添加一个或多个同步或异步任务到链中
|
|
|
|
支持多种调用方式:
|
|
1. 单个函数: add(func, *args, **kwargs)
|
|
2. 多个函数: add((func1, args1, kwargs1), (func2, args2, kwargs2), ...)
|
|
3. 混合方式: add(func1, *args1, **kwargs1), (func2, args2, kwargs2), ...
|
|
"""
|
|
for item in funcs_or_tuples:
|
|
if isinstance(item, tuple) and len(item) == 3:
|
|
# 如果是三元组 (func, args, kwargs)
|
|
func, args, kwargs = item
|
|
self._tasks.append((func, args, kwargs))
|
|
elif callable(item):
|
|
# 如果是单个函数,需要检查后续参数
|
|
if (
|
|
len(funcs_or_tuples) >= 3
|
|
and isinstance(funcs_or_tuples[1], tuple)
|
|
and isinstance(funcs_or_tuples[2], dict)
|
|
):
|
|
# 如果是 add(func, args, kwargs) 格式
|
|
func = item
|
|
args = funcs_or_tuples[1] if len(funcs_or_tuples) > 1 else ()
|
|
kwargs = funcs_or_tuples[2] if len(funcs_or_tuples) > 2 else {}
|
|
self._tasks.append((func, args, kwargs))
|
|
break # 处理完这个函数后退出循环
|
|
else:
|
|
# 单个函数没有参数
|
|
self._tasks.append((item, (), {}))
|
|
else:
|
|
raise ValueError(f"不支持的参数类型: {type(item)}")
|
|
return self
|
|
|
|
async def _execute_async(self):
|
|
"""异步执行所有任务"""
|
|
for func, args, kwargs in self._tasks:
|
|
if asyncio.iscoroutinefunction(func):
|
|
await func(*args, **kwargs)
|
|
else:
|
|
# 如果是同步函数,在事件循环中运行
|
|
func(*args, **kwargs)
|
|
|
|
def __call__(self):
|
|
"""同步调用接口"""
|
|
import asyncio
|
|
|
|
# 检查是否有异步任务
|
|
has_async = any(asyncio.iscoroutinefunction(func) for func, _, _ in self._tasks)
|
|
|
|
if has_async:
|
|
# 如果有异步任务,创建并运行事件循环
|
|
loop = asyncio.get_event_loop()
|
|
if loop.is_running():
|
|
# 如果事件循环已经在运行,创建任务
|
|
asyncio.create_task(self._execute_async())
|
|
else:
|
|
# 否则运行事件循环
|
|
loop.run_until_complete(self._execute_async())
|
|
else:
|
|
# 如果都是同步任务,直接执行
|
|
for func, args, kwargs in self._tasks:
|
|
func(*args, **kwargs)
|
|
|
|
return self
|
|
|
|
async def __aenter__(self):
|
|
"""异步上下文管理器入口"""
|
|
return self
|
|
|
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
|
"""异步上下文管理器出口"""
|
|
pass
|
|
|
|
async def execute_async(self):
|
|
"""显式异步执行方法"""
|
|
await self._execute_async()
|
|
return self
|
|
|
|
|
|
async def data_base_init():
|
|
from service.sync.department import check_department_datebase, sync_department
|
|
from service.sync.employee import check_employee_datebase, sync_department_user
|
|
|
|
if not check_department_datebase():
|
|
logger.info("[数据库] 开始同步部门 📦")
|
|
await sync_department()
|
|
logger.info("[数据库] 同步部门完成 📦")
|
|
if not check_employee_datebase():
|
|
logger.info("[数据库] 开始同步员工 📦")
|
|
await sync_department_user()
|
|
logger.info("[数据库] 同步员工完成 📦")
|
|
|
|
|
|
def init_database():
|
|
from model import create_db_and_tables
|
|
|
|
logger.info("[数据库] 初始化数据库 📦")
|
|
create_db_and_tables()
|
|
logger.info("[数据库] 数据库初始化完成 ✅")
|
|
|
|
|
|
def init_scheduler(app: FastAPI):
|
|
from scheduler import init_scheduler_router
|
|
|
|
logger.info("[定时任务] 初始化定时任务 📦")
|
|
init_scheduler_router(app)
|
|
logger.info("[定时任务] 定时任务初始化完成 ✅")
|
|
|
|
|
|
def active_config():
|
|
logger.info("[激活配置] 加载配置 ⚙️")
|
|
from config import Settings # noqa
|
|
|
|
|
|
def import_router(app: FastAPI):
|
|
logger.info("[导入路由] 开始导入路由 🛣️")
|
|
from router import router
|
|
|
|
app.include_router(router)
|
|
logger.info("[导入路由] 路由导入完成 ✅")
|
|
|
|
|
|
async def import_mcp_server(app: FastAPI):
|
|
logger.info("[导入MCP] 开始导入MCP 🛣️")
|
|
from mcps import create_mcp_app
|
|
|
|
app.mount("/app", await create_mcp_app())
|
|
logger.info("[导入MCP] MCP导入完成 ✅")
|
|
|
|
|
|
@asynccontextmanager
|
|
async def lifespan(app: FastAPI):
|
|
logger.info("[生命周期] 应用启动 🚀")
|
|
builder = ChainBuilder()
|
|
|
|
# 激活配置
|
|
builder.add(active_config)
|
|
# 初始化数据库
|
|
builder.add(init_database).add(data_base_init)
|
|
# 导入MCP
|
|
builder.add(import_mcp_server, app)
|
|
# 导入路由
|
|
builder.add(import_router, app)
|
|
# 初始化定时任务
|
|
builder.add(init_scheduler, app)
|
|
await builder.execute_async()
|
|
yield
|
|
logger.info("[生命周期] 应用关闭 🔧✅")
|