Files
intelligent-daily-report-sy…/plugin/monitor/middleware.py
2026-02-25 15:22:23 +08:00

268 lines
9.2 KiB
Python

import asyncio
import time
from collections import defaultdict
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional
from fastapi import FastAPI, Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from plugin.base import Plugin
from uvicorn.server import logger
class APIMonitorMiddleware(BaseHTTPMiddleware):
    """Log every HTTP request and keep sliding-window request counts per endpoint.

    Each request is keyed as ``"<client_ip>:<METHOD>:<path>"``. For every key the
    middleware keeps the ``time.time()`` timestamps of requests seen within the
    last ``rate_limit_window`` seconds. A background task wakes every
    ``2 * rate_limit_window`` seconds to drop expired timestamps and keys that
    have no recent requests.

    NOTE: because timestamps are pruned, the "total_requests" figures reported
    by the stats methods count only requests that have not been pruned yet
    (roughly the last window), not lifetime totals.
    """

    def __init__(
        self,
        app: FastAPI,
        *,
        entity: object = None,
        exclude_paths: Optional[List[str]] = None,
        rate_limit_window: int = 60,
    ):
        """Create the middleware.

        Args:
            app: The ASGI application wrapped by this middleware.
            entity: Optional owner object; when given, its ``middleware``
                attribute is set to this instance so the owner can read
                statistics later.
            exclude_paths: Path prefixes that bypass monitoring entirely.
            rate_limit_window: Length of the sliding window, in seconds.

        Raises:
            ValueError: If ``rate_limit_window`` is not positive.
        """
        super().__init__(app)
        if rate_limit_window <= 0:
            # A non-positive window makes the cleanup loop spin without sleeping
            # and causes a division by zero in get_request_stats().
            raise ValueError("rate_limit_window must be a positive number of seconds")
        self.exclude_paths = exclude_paths or []
        self.rate_limit_window = rate_limit_window
        # key -> number of requests currently inside the window
        self.request_counts = defaultdict(int)
        # key -> time.time() timestamps of requests currently inside the window
        self.request_timestamps = defaultdict(list)
        # Wall-clock time monitoring started; reported by get_summary_stats().
        self.started_at = time.time()
        # Let the owner (if any) reach this instance to query statistics.
        if entity is not None:
            entity.middleware = self
        # Background cleanup needs a running event loop. If the middleware is
        # constructed outside one, dispatch() starts the task on first request.
        self.cleanup_task: Optional[asyncio.Task] = None
        self._ensure_cleanup_task()

    def _ensure_cleanup_task(self) -> None:
        """Start the background cleanup task if it is not running and a loop exists."""
        if self.cleanup_task is not None and not self.cleanup_task.done():
            return
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No running loop yet; a later dispatch() call will retry.
            return
        self.cleanup_task = asyncio.create_task(self._cleanup_old_records())

    async def _cleanup_old_records(self):
        """Periodically drop expired timestamps and keys with no recent requests.

        Loops forever, sleeping ``2 * rate_limit_window`` seconds between passes,
        and returns quietly when cancelled.
        """
        try:
            while True:
                await asyncio.sleep(self.rate_limit_window * 2)
                self._prune(time.time())
        except asyncio.CancelledError:
            # Cancellation is the normal way to stop this task.
            pass

    def _prune(self, now: float) -> None:
        """Remove timestamps older than the window; forget keys left with none."""
        cutoff = now - self.rate_limit_window
        for key in list(self.request_timestamps):
            recent = [ts for ts in self.request_timestamps[key] if ts > cutoff]
            if recent:
                self.request_timestamps[key] = recent
                # Keep the count in sync with the pruned timestamps.
                self.request_counts[key] = len(recent)
            else:
                del self.request_timestamps[key]
                self.request_counts.pop(key, None)

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Forward the request, then record its statistics and log it.

        Requests whose path starts with any entry of ``exclude_paths`` pass
        through untouched. An exception raised downstream is recorded as
        status 500 and re-raised so FastAPI's exception handling still applies.
        """
        if self.exclude_paths and request.url.path.startswith(tuple(self.exclude_paths)):
            return await call_next(request)

        self._ensure_cleanup_task()

        start_time = time.time()
        # Defaults cover BaseExceptions (e.g. asyncio.CancelledError) that the
        # except clause below does not catch; without them the finally block
        # would hit unbound locals and mask the original exception.
        status_code = 500
        error: Optional[str] = None
        try:
            response = await call_next(request)
            status_code = response.status_code
        except Exception as exc:
            error = str(exc)
            # Re-raise so FastAPI/Starlette exception handlers still see it.
            raise
        finally:
            process_time = time.time() - start_time
            # Record stats and log whether or not the request failed.
            await self._update_request_stats(request, status_code)
            await self._log_request(request, status_code, start_time, process_time, error)
        return response

    @staticmethod
    def _client_ip(request: Request) -> str:
        """Return the client's host, or ``"unknown"`` when it is not available."""
        return request.client.host if request.client else "unknown"

    @staticmethod
    def _parse_content_length(raw: Optional[str]) -> int:
        """Parse a Content-Length header value; return 0 if it is missing or malformed."""
        if raw is None:
            return 0
        try:
            return int(raw)
        except ValueError:
            # The header is client-controlled; never let a bad value break the request.
            return 0

    @staticmethod
    def _split_key(key: str) -> Optional[tuple]:
        """Split an ``"<ip>:<METHOD>:<path>"`` key into ``(ip, method, path)``.

        IPv6 addresses contain colons, and URL paths may too, so a plain
        ``split(":")`` is ambiguous. The path is located by its leading ``/``
        (neither an IP address nor an HTTP method contains one). Keys that do
        not match that shape fall back to splitting on the first two colons.
        Returns ``None`` if the key cannot be split into three parts.
        """
        slash = key.find("/")
        if slash > 0 and key[slash - 1] == ":":
            ip, sep, method = key[: slash - 1].rpartition(":")
            if sep:
                return ip, method, key[slash:]
        parts = key.split(":", 2)
        return tuple(parts) if len(parts) == 3 else None

    async def _update_request_stats(self, request: Request, status_code: int):
        """Record the current request under its key and refresh the key's count.

        ``status_code`` is accepted for interface compatibility but is not used.
        """
        endpoint_key = f"{self._client_ip(request)}:{request.method}:{request.url.path}"
        now = time.time()
        window_start = now - self.rate_limit_window
        # Keep only timestamps still inside the window, then add this request.
        timestamps = [
            ts for ts in self.request_timestamps[endpoint_key] if ts >= window_start
        ]
        timestamps.append(now)
        self.request_timestamps[endpoint_key] = timestamps
        self.request_counts[endpoint_key] = len(timestamps)

    async def _log_request(
        self,
        request: Request,
        status_code: int,
        start_time: float,
        process_time: float = 0,
        error: Optional[str] = None,
    ):
        """Emit one log line for a request, at a level chosen from its status code.

        Args:
            request: The incoming request.
            status_code: Response status (500 when the handler raised).
            start_time: ``time.time()`` when handling began.
            process_time: Handling duration in seconds.
            error: Exception text if the handler raised, else ``None``.
        """
        user_agent = request.headers.get("user-agent", "")
        referer = request.headers.get("referer", "")
        log_data = {
            "timestamp": datetime.fromtimestamp(start_time).isoformat(),
            "client_ip": self._client_ip(request),
            "method": request.method,
            "url": str(request.url),
            "path": request.url.path,
            "query_params": dict(request.query_params),
            "status_code": status_code,
            "process_time_ms": round(process_time * 1000, 2),
            "user_agent": user_agent[:200],  # cap length
            "referer": referer[:200],
            "content_type": request.headers.get("content-type", ""),
            "content_length": self._parse_content_length(
                request.headers.get("content-length")
            ),
            "error": error,
        }
        # Flatten the interesting fields into a single log line.
        log_entry = f"耗时:{log_data['process_time_ms']}ms,错误:{error}, 请求参数: {str(log_data['query_params'])}, 请求url: {log_data['url']} "
        # NOTE(review): 4xx client errors are logged at ERROR, same as 5xx;
        # consider WARNING. 1xx responses are not logged at all.
        if status_code >= 400:
            logger.error(log_entry)
        elif status_code >= 300:
            logger.info(log_entry)
        elif status_code >= 200:
            logger.debug(log_entry)

    def get_request_stats(
        self, ip: Optional[str] = None, path: Optional[str] = None
    ) -> Dict[str, Any]:
        """Return per-key request statistics, optionally filtered by client IP and/or path.

        Args:
            ip: If given, only keys for this client address are included.
            path: If given, only keys for this exact URL path are included.

        Returns:
            Mapping from ``"<ip>:<METHOD>:<path>"`` key to a dict of metrics.
        """
        stats: Dict[str, Any] = {}
        now = time.time()
        for key in list(self.request_counts):
            parts = self._split_key(key)
            if parts is None:
                continue
            client_ip, method, endpoint_path = parts
            if ip and client_ip != ip:
                continue
            if path and endpoint_path != path:
                continue
            timestamps = self.request_timestamps.get(key, [])
            recent_requests = [
                ts for ts in timestamps if now - ts < self.rate_limit_window
            ]
            stats[key] = {
                "total_requests": len(timestamps),
                "current_window_requests": len(recent_requests),
                "requests_per_minute": len(recent_requests)
                * (60 / self.rate_limit_window),
                "last_request_time": datetime.fromtimestamp(max(timestamps)).isoformat()
                if timestamps
                else None,
                "method": method,
                "endpoint": endpoint_path,
                "client_ip": client_ip,
            }
        return stats

    def get_summary_stats(self) -> Dict[str, Any]:
        """Return aggregate counts across all tracked keys.

        ``unique_endpoints`` counts distinct URL paths, regardless of method.
        """
        parsed = [
            parts
            for parts in map(self._split_key, list(self.request_timestamps))
            if parts is not None
        ]
        return {
            "total_requests": sum(
                len(timestamps) for timestamps in self.request_timestamps.values()
            ),
            "unique_clients": len({parts[0] for parts in parsed}),
            "unique_endpoints": len({parts[2] for parts in parsed}),
            "monitoring_since": datetime.fromtimestamp(self.started_at).isoformat(),
            "time_window_seconds": self.rate_limit_window,
        }
class Monitor(Plugin):
    """Plugin that installs APIMonitorMiddleware and exposes its statistics over HTTP."""

    def __init__(
        self,
        app: FastAPI,
        exclude_paths: Optional[List[str]] = None,
        rate_limit_window: int = 60,
    ):
        """Store configuration; nothing is registered on the app until install().

        Args:
            app: The FastAPI application to monitor.
            exclude_paths: Path prefixes that should not be monitored.
                Defaults to None (monitor everything).
            rate_limit_window: Sliding-window length in seconds used for
                request-rate statistics. Defaults to 60.
        """
        logger.info("[监控插件] 加载监控插件 🔧")
        self.app = app
        self.exclude_paths = exclude_paths
        self.rate_limit_window = rate_limit_window
        # Filled in by APIMonitorMiddleware when it is constructed, since
        # install() passes this object to it as ``entity``.
        self.middleware = None
        self.name = "Monitor"
        self.version = "1.0.0"

    async def get_stats(self, ip: Optional[str] = None, path: Optional[str] = None):
        """Return summary and per-key request statistics.

        Args:
            ip: Optional client address to filter the detailed stats by.
            path: Optional exact URL path to filter the detailed stats by.

        Returns:
            ``{"summary": ..., "detailed_stats": ...}``, or an ``{"error": ...}``
            dict if the middleware has not been constructed yet.
        """
        if self.middleware:
            return {
                "summary": self.middleware.get_summary_stats(),
                "detailed_stats": self.middleware.get_request_stats(ip, path),
            }
        return {"error": "Monitor middleware not initialized"}

    def install(self):
        """Register the monitoring middleware and the ``GET /api/stats`` route."""
        self.app.add_middleware(
            APIMonitorMiddleware,
            entity=self,
            exclude_paths=self.exclude_paths,
            rate_limit_window=self.rate_limit_window,
        )
        # NOTE(review): this route has no authentication and its response
        # includes client IP addresses; consider protecting it.
        self.app.add_api_route(
            path="/api/stats",
            endpoint=self.get_stats,
            methods=["GET"],
            summary="获取API统计信息",
            description="获取API统计信息",
        )
        logger.info("[监控插件] 监控插件已加载 ✅")