Initial commit

This commit is contained in:
2026-02-25 15:22:23 +08:00
commit c7138dab9e
84 changed files with 14690 additions and 0 deletions

17
plugin/__init__.py Normal file
View File

@@ -0,0 +1,17 @@
from .base import PluginManager
from .cors import CORSMiddleware
from .monitor import Monitor
from .profiler import Profiler
from .spa import SpaProxy
from .user import AuthUser, UserManager, UserSource
__all__ = [
"PluginManager",
"Monitor",
"SpaProxy",
"CORSMiddleware",
"AuthUser",
"UserManager",
"UserSource",
"Profiler",
]

34
plugin/base.py Normal file
View File

@@ -0,0 +1,34 @@
from fastapi import FastAPI
from uvicorn.server import logger
class Plugin:
def __init__(
self, app: FastAPI, name="Unnamed Plugin", version="1.0.0", *args, **kwargs
):
self.app = app
self.name = name
self.version = version
def install(self):
pass
class PluginManager:
def __init__(self, app: FastAPI):
self.app = app
def register_plugin(self, plugin: Plugin):
plugin_name = getattr(plugin, "name", "Unnamed Plugin")
plugin_version = getattr(plugin, "version", "Unknown Version")
logger.info(f"[插件] Registering plugin: [{plugin_name} {plugin_version}] ")
try:
plugin.install()
logger.info(
f"[插件] Plugin [{plugin_name} {plugin_version}] installed successfully ✅"
)
except Exception as e:
logger.error(
f"[插件] Failed to install plugin [{plugin_name} {plugin_version}]: {e}"
)

3
plugin/cors/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
from .cors import CorsPlugin as CORSMiddleware
__all__ = ["CORSMiddleware"]

23
plugin/cors/cors.py Normal file
View File

@@ -0,0 +1,23 @@
from fastapi import FastAPI
from plugin.base import Plugin
class CorsPlugin(Plugin):
def __init__(self, app: FastAPI):
self.name = "CorsPlugin"
self.description = "Enable Cross-Origin Resource Sharing (CORS)"
self.version = "1.0.0"
self.app = app
def install(self):
from fastapi.middleware.cors import CORSMiddleware
# 添加CORS中间件解决跨域问题
self.app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)

View File

@@ -0,0 +1,3 @@
from .middleware import Monitor
__all__ = ["Monitor"]

View File

@@ -0,0 +1,267 @@
import asyncio
import time
from collections import defaultdict
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional
from fastapi import FastAPI, Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from plugin.base import Plugin
from uvicorn.server import logger
class APIMonitorMiddleware(BaseHTTPMiddleware):
def __init__(
self,
app: FastAPI,
*,
entity: object = None,
exclude_paths: Optional[List[str]] = None,
rate_limit_window: int = 60,
):
super().__init__(app)
self.exclude_paths = exclude_paths or []
self.rate_limit_window = rate_limit_window
# 用于存储请求频率数据
self.request_counts = defaultdict(int)
self.request_timestamps = defaultdict(list)
# 存储中间件实例以便后续访问
# app.state.api_monitor = self
entity.middleware = self
# 启动后台清理任务
self.cleanup_task = asyncio.create_task(self._cleanup_old_records())
async def _cleanup_old_records(self):
"""定期清理过期的请求记录"""
try:
while True:
await asyncio.sleep(self.rate_limit_window * 2)
current_time = time.time()
keys_to_remove = []
for key in list(self.request_timestamps.keys()):
# 移除超过时间窗口的时间戳
self.request_timestamps[key] = [
ts
for ts in self.request_timestamps[key]
if current_time - ts < self.rate_limit_window
]
if not self.request_timestamps[key]:
keys_to_remove.append(key)
for key in keys_to_remove:
del self.request_timestamps[key]
if key in self.request_counts:
del self.request_counts[key]
except asyncio.CancelledError:
# 任务被取消时正常退出
pass
async def dispatch(self, request: Request, call_next: Callable) -> Response:
# 检查是否排除该路径
if any(request.url.path.startswith(path) for path in self.exclude_paths):
return await call_next(request)
# 记录开始时间
start_time = time.time()
try:
# 处理请求
response = await call_next(request)
process_time = time.time() - start_time
status_code = response.status_code
# response.headers["X-Process-Time"] = str(process_time)
error = None
except Exception as e:
process_time = time.time() - start_time
response = None
status_code = 500
error = str(e)
# 重新抛出异常以便FastAPI处理
raise
finally:
# 无论是否异常都记录日志
# 更新请求频率统计
await self._update_request_stats(request, status_code)
# 记录请求信息
await self._log_request(
request, status_code, start_time, process_time, error
)
return response if response else Response(status_code=status_code)
async def _update_request_stats(self, request: Request, status_code: int):
"""更新请求频率统计"""
client_ip = request.client.host if request.client else "unknown"
endpoint_key = f"{client_ip}:{request.method}:{request.url.path}"
current_time = time.time()
self.request_timestamps[endpoint_key].append(current_time)
# 计算当前时间窗口内的请求次数
window_start = current_time - self.rate_limit_window
self.request_timestamps[endpoint_key] = [
ts for ts in self.request_timestamps[endpoint_key] if ts >= window_start
]
self.request_counts[endpoint_key] = len(self.request_timestamps[endpoint_key])
async def _log_request(
self,
request: Request,
status_code: int,
start_time: float,
process_time: float = 0,
error: Optional[str] = None,
):
"""记录请求日志"""
client_ip = request.client.host if request.client else "unknown"
user_agent = request.headers.get("user-agent", "")
referer = request.headers.get("referer", "")
log_data = {
"timestamp": datetime.fromtimestamp(start_time).isoformat(),
"client_ip": client_ip,
"method": request.method,
"url": str(request.url),
"path": request.url.path,
"query_params": dict(request.query_params),
# "body": await request.json(),
"status_code": status_code,
"process_time_ms": round(process_time * 1000, 2),
"user_agent": user_agent[:200], # 限制长度
"referer": referer[:200],
"content_type": request.headers.get("content-type", ""),
"content_length": int(request.headers.get("content-length", 0)),
"error": error,
}
# 拼接为一条字符串
log_entry = f"耗时:{log_data['process_time_ms']}ms,错误:{error}, 请求参数: {str(log_data['query_params'])}, 请求url: {log_data['url']} "
# 根据状态码决定日志级别
if status_code >= 500:
logger.error(log_entry)
elif status_code >= 400:
logger.error(log_entry)
elif status_code >= 300:
logger.info(log_entry)
elif status_code >= 200:
logger.debug(log_entry)
def get_request_stats(
self, ip: Optional[str] = None, path: Optional[str] = None
) -> Dict[str, Any]:
"""获取请求统计信息"""
stats = {}
current_time = time.time()
for key in list(self.request_counts.keys()):
parts = key.split(":", 2)
if len(parts) < 3:
continue
client_ip, method, endpoint_path = parts
if ip and client_ip != ip:
continue
if path and endpoint_path != path:
continue
# 计算各种统计指标
timestamps = self.request_timestamps.get(key, [])
recent_requests = [
ts for ts in timestamps if current_time - ts < self.rate_limit_window
]
stats[key] = {
"total_requests": len(timestamps),
"current_window_requests": len(recent_requests),
"requests_per_minute": len(recent_requests)
* (60 / self.rate_limit_window),
"last_request_time": datetime.fromtimestamp(max(timestamps)).isoformat()
if timestamps
else None,
"method": method,
"endpoint": endpoint_path,
"client_ip": client_ip,
}
return stats
def get_summary_stats(self) -> Dict[str, Any]:
"""获取汇总统计信息"""
total_requests = sum(
len(timestamps) for timestamps in self.request_timestamps.values()
)
unique_clients = len(
set(key.split(":")[0] for key in self.request_timestamps.keys())
)
unique_endpoints = len(
set(
key.split(":")[2]
for key in self.request_timestamps.keys()
if len(key.split(":")) >= 3
)
)
return {
"total_requests": total_requests,
"unique_clients": unique_clients,
"unique_endpoints": unique_endpoints,
"monitoring_since": datetime.now().isoformat(),
"time_window_seconds": self.rate_limit_window,
}
class Monitor(Plugin):
def __init__(
self, app: FastAPI, exclude_paths: list = None, rate_limit_window: int = 60
):
"""_summary_
Args:
app (FastAPI): _description_
exclude_paths (list, optional): _description_. Defaults to None.
rate_limit_window (int, optional): _description_. Defaults to 60.
"""
logger.info("[监控插件] 加载监控插件 🔧")
self.app = app
self.exclude_paths = exclude_paths
self.rate_limit_window = rate_limit_window
self.middleware = None
self.name = "Monitor"
self.version = "1.0.0"
async def get_stats(self, ip: str = None, path: str = None):
"""获取API统计信息"""
if self.middleware:
return {
"summary": self.middleware.get_summary_stats(),
"detailed_stats": self.middleware.get_request_stats(ip, path),
}
return {"error": "Monitor middleware not initialized"}
def install(self):
self.app.add_middleware(
APIMonitorMiddleware,
entity=self,
exclude_paths=self.exclude_paths,
rate_limit_window=self.rate_limit_window,
)
self.app.add_api_route(
path="/api/stats",
endpoint=self.get_stats,
methods=["GET"],
summary="获取API统计信息",
description="获取API统计信息",
)
logger.info("[监控插件] 监控插件已加载 ✅")

View File

@@ -0,0 +1,3 @@
from .profiler import Profiler
__all__ = ["Profiler"]

View File

@@ -0,0 +1,15 @@
from fastapi import FastAPI
from fastapi_profiler import Profiler as FastapiProfilerMiddleware
from plugin.base import Plugin
class Profiler(Plugin):
def __init__(self, app: FastAPI, dashboard_path: str = "/profiler"):
self.app = app
self.dashboard_path = dashboard_path
self.name = "Profiler"
self.version = "1.0.0"
def install(self):
FastapiProfilerMiddleware(self.app, self.dashboard_path)

3
plugin/spa/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
from .spa_proxy import SpaProxy
__all__ = ["SpaProxy"]

50
plugin/spa/spa_proxy.py Normal file
View File

@@ -0,0 +1,50 @@
import os
from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from plugin.base import Plugin
from uvicorn.server import logger
class SpaProxy(Plugin):
def __init__(self, app: FastAPI, dist: str = "dist"):
self.app = app
self.dist = dist
self.name = "SPA Proxy"
self.version = "1.0.0"
def install(self):
# --- [新增] 挂载前端静态文件 --
frontend_dir = os.path.join(os.getcwd(), self.dist)
if not os.path.isdir(frontend_dir):
return
# 1. 挂载静态资源(如 /assets/...
assets_dir = os.path.join(frontend_dir, "assets")
if os.path.isdir(assets_dir):
self.app.mount("/assets", StaticFiles(directory=assets_dir), name="assets")
# 2. 【关键】兜底路由:处理所有未被 API 匹配的路径(支持前端路由如 /analysis
@self.app.get("/{full_path:path}")
async def serve_frontend(full_path: str):
# 如果请求的是 API 路径,不应走到这里(因为 API 路由已先注册)
index_path = os.path.join(frontend_dir, "index.html")
if os.path.isfile(index_path):
return FileResponse(index_path)
raise HTTPException(status_code=404, detail="Frontend not found")
# 3. 显式处理根路径(可选,但更清晰)
@self.app.get("/")
async def root():
return await serve_frontend("")
# self.app.add_api_route(
# "/{full_path:path}",
# serve_frontend,
# methods=["GET"],
# response_class=FileResponse,
# )
logger.info(f"前端静态文件挂载: {frontend_dir}")

3
plugin/user/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
from .user_manager import AuthUser, UserManager, UserSource
__all__ = ["AuthUser", "UserManager", "UserSource"]

190
plugin/user/user_manager.py Normal file
View File

@@ -0,0 +1,190 @@
import datetime
import secrets
import uuid
from typing import Generic, TypedDict, TypeVar
import jwt
from fastapi import HTTPException, Request
from pydantic import BaseModel
from typing_extensions import List
ID_TYPE = str | int
class AuthUser(BaseModel):
id: ID_TYPE
def __init__(self, *args, **kwargs):
if "id" not in kwargs:
kwargs["id"] = str(uuid.uuid4())
super().__init__(*args, **kwargs)
T = TypeVar("T")
class Paginate(TypedDict, Generic[T]):
items: list[T]
total: int
pages: int
current_page: int
per_page: int
class UserSource(Generic[T]):
def __init__(self, *args, **kwargs) -> None:
self._data: dict[ID_TYPE, T] = dict[ID_TYPE, T](*args, **kwargs)
def __getitem__(self, key: ID_TYPE) -> T:
return self._data[key]
def __setitem__(self, key: ID_TYPE, value: T) -> None:
self._data[key] = value
def __delitem__(self, key: ID_TYPE) -> None:
del self._data[key]
def __contains__(self, key: ID_TYPE) -> bool:
return key in self._data
def keys(self) -> List[ID_TYPE]:
return list(self._data.keys())
def get(self, key: ID_TYPE, default: T | None = None) -> T | None:
return self._data.get(key, default)
def paginate(self, page: int = 1, per_page: int = 10) -> Paginate[T]:
"""分页获取用户数据
Args:
page: 页码从1开始
per_page: 每页数量
Returns:
包含分页信息的字典,包括:
- items: 当前页的数据列表
- total: 总数据量
- pages: 总页数
- current_page: 当前页码
- per_page: 每页数量
"""
total = len(self._data)
pages = (total + per_page - 1) // per_page if per_page > 0 else 0
if page < 1:
page = 1
if page > pages and pages > 0:
page = pages
start = (page - 1) * per_page
end = start + per_page
items = list(self._data.values())[start:end]
return {
"items": items,
"total": total,
"pages": pages,
"current_page": page,
"per_page": per_page,
}
def __repr__(self) -> str:
return f"UserSource({self._data})"
TUser = TypeVar("TUser", bound=AuthUser)
class UserManager(Generic[TUser]):
users: UserSource[TUser] = UserSource()
secret_key: str = secrets.token_urlsafe(32)
expiration_time: int = 1440
algorithm: str = "HS256"
token_type: str = "Bearer"
auth_header_field: str = "Authorization"
def get_user(self, user_id: ID_TYPE) -> TUser | None:
return self.users.get(user_id, None)
def add_user(self, user: TUser) -> bool:
if user.id in self.users.keys():
return False
self.users[user.id] = user
return True
def delete_user(self, user_id: ID_TYPE) -> bool:
if user_id not in self.users:
return False
del self.users[user_id]
return True
def update_user(self, user: TUser) -> bool:
if user.id not in self.users:
return False
self.users[user.id] = user
return True
def create_token(self, user: TUser) -> str:
payload = {
**user.model_dump(),
"exp": datetime.datetime.now(datetime.timezone.utc)
+ datetime.timedelta(minutes=self.expiration_time),
}
token = jwt.encode(payload, self.secret_key, algorithm=self.algorithm)
return token
def paginate_users(self, page: int = 1, per_page: int = 10) -> Paginate[TUser]:
"""分页查询用户
Args:
page: 页码从1开始
per_page: 每页数量
Returns:
包含分页信息的字典
"""
return self.users.paginate(page=page, per_page=per_page)
def verify_token(self, token: str) -> TUser | None:
import jwt
try:
payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
user_id = payload.get("id")
if user_id is None:
return None
return self.get_user(user_id)
except jwt.ExpiredSignatureError:
raise jwt.ExpiredSignatureError
except jwt.InvalidTokenError:
raise jwt.InvalidTokenError
def refresh_token(self, token: str) -> str | None:
payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
user_id = payload.get("id")
if user_id is None:
return None
user = self.get_user(user_id)
if user is None:
return None
return self.create_token(user)
def fastapi_get_user(self, request: Request) -> TUser | None:
"""FastAPI dependency to get current user from token"""
auth_header = request.headers.get(self.auth_header_field)
if not auth_header:
raise HTTPException(status_code=401, detail="Invalid token")
parts = auth_header.split()
if len(parts) != 2 or parts[0] != self.token_type:
raise HTTPException(status_code=401, detail="Invalid token")
token = parts[1]
try:
user = self.verify_token(token)
except Exception as e:
raise HTTPException(status_code=401, detail=f"Invalid token {e}")
if not user:
raise HTTPException(status_code=401, detail="User not found")
return user