268 lines
9.2 KiB
Python
268 lines
9.2 KiB
Python
import asyncio
|
|
import time
|
|
from collections import defaultdict
|
|
from datetime import datetime
|
|
from typing import Any, Callable, Dict, List, Optional
|
|
|
|
from fastapi import FastAPI, Request, Response
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
|
|
from plugin.base import Plugin
|
|
from uvicorn.server import logger
|
|
|
|
|
|
class APIMonitorMiddleware(BaseHTTPMiddleware):
    """Log every request and keep sliding-window request statistics.

    Statistics are keyed by ``"client_ip:method:path"``. For each key the
    middleware keeps the timestamps of the requests seen within the last
    ``rate_limit_window`` seconds (``request_timestamps``) and how many there
    are (``request_counts``). A background task periodically prunes expired
    entries so keys of idle clients do not accumulate.
    """

    def __init__(
        self,
        app: FastAPI,
        *,
        entity: object = None,
        exclude_paths: Optional[List[str]] = None,
        rate_limit_window: int = 60,
    ):
        """Create the middleware.

        Args:
            app: The ASGI application wrapped by this middleware.
            entity: Optional owner object. When given, its ``middleware``
                attribute is set to this instance so the owner can query the
                statistics later.
            exclude_paths: Path prefixes whose requests are forwarded without
                being counted or logged.
            rate_limit_window: Length in seconds of the sliding window used
                for the statistics. Must be positive.

        Raises:
            ValueError: If ``rate_limit_window`` is not positive. A zero or
                negative window would make the cleanup loop spin and the
                per-minute rate divide by zero.
        """
        if rate_limit_window <= 0:
            raise ValueError(
                f"rate_limit_window must be positive, got {rate_limit_window}"
            )
        super().__init__(app)
        self.exclude_paths = exclude_paths or []
        self.rate_limit_window = rate_limit_window

        # Request-frequency data, keyed by "client_ip:method:path".
        self.request_counts = defaultdict(int)
        self.request_timestamps = defaultdict(list)

        # When this instance started collecting data (reported as
        # "monitoring_since" by get_summary_stats).
        self.started_at = datetime.now()

        # Expose this instance to its owner. ``entity`` defaults to None, so
        # only attach when one was actually supplied.
        if entity is not None:
            entity.middleware = self

        # Background pruning task. Creating it requires a running event loop,
        # which may not exist yet when the middleware is constructed; in that
        # case it is started by the first request that passes through.
        self.cleanup_task: Optional[asyncio.Task] = None
        self._ensure_cleanup_task()

    def _ensure_cleanup_task(self) -> None:
        """Start the cleanup task unless one is already running.

        Does nothing when no event loop is running in the current thread; a
        later call (from ``dispatch``) will try again.
        """
        if self.cleanup_task is not None and not self.cleanup_task.done():
            return
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            return
        self.cleanup_task = loop.create_task(self._cleanup_old_records())

    async def _cleanup_old_records(self):
        """Prune expired timestamps every ``2 * rate_limit_window`` seconds."""
        try:
            while True:
                await asyncio.sleep(self.rate_limit_window * 2)
                self._prune_expired(time.time())
        except asyncio.CancelledError:
            # Cancelled (e.g. on shutdown): exit quietly.
            pass

    def _prune_expired(self, now: float) -> None:
        """Drop timestamps outside the window; remove keys left with none.

        ``request_counts`` is refreshed for surviving keys so it never
        reports requests that have already expired.
        """
        window_start = now - self.rate_limit_window
        for key in list(self.request_timestamps):
            recent = [ts for ts in self.request_timestamps[key] if ts > window_start]
            if recent:
                self.request_timestamps[key] = recent
                self.request_counts[key] = len(recent)
            else:
                del self.request_timestamps[key]
                self.request_counts.pop(key, None)

    @staticmethod
    def _parse_key(key: str) -> Optional[tuple]:
        """Split an ``"ip:method:path"`` key into ``(ip, method, path)``.

        A plain ``split(":")`` is ambiguous because IPv6 client addresses
        (e.g. ``::1``) and request paths may both contain colons. Neither the
        IP nor the HTTP method contains ``/`` while the path normally starts
        with ``/``, so the first ``":/"`` marks the method/path boundary and
        the last ``:`` before it separates IP from method. Keys without that
        marker fall back to ``split(":", 2)``; ``None`` means unparseable.
        """
        head, sep, tail = key.partition(":/")
        if sep:
            client_ip, colon, method = head.rpartition(":")
            if colon:
                return client_ip, method, "/" + tail
        parts = key.split(":", 2)
        if len(parts) < 3:
            return None
        return parts[0], parts[1], parts[2]

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Forward the request, then record statistics and a log line for it.

        Requests whose path starts with any prefix in ``exclude_paths`` are
        forwarded untouched. An ``Exception`` raised downstream is recorded
        with status 500 and its message, then re-raised so FastAPI/Starlette
        can handle it.
        """
        self._ensure_cleanup_task()

        if request.url.path.startswith(tuple(self.exclude_paths)):
            return await call_next(request)

        start_time = time.time()
        # Defaults used by ``finally`` when call_next raises something that
        # is not an Exception subclass (e.g. asyncio.CancelledError); without
        # them these names would be unbound there.
        status_code = 500
        error: Optional[str] = None
        try:
            response = await call_next(request)
            status_code = response.status_code
            return response
        except Exception as exc:
            error = str(exc)
            # Re-raise so FastAPI can handle the exception.
            raise
        finally:
            # Record the request whether or not it failed.
            process_time = time.time() - start_time
            await self._record(request, status_code, start_time, process_time, error)

    async def _record(
        self,
        request: Request,
        status_code: int,
        start_time: float,
        process_time: float,
        error: Optional[str],
    ) -> None:
        """Update statistics and log the request without ever raising.

        Called from ``dispatch``'s ``finally`` block, where an exception
        would replace the real response (or mask the real error), so any
        failure here is logged and swallowed.
        """
        try:
            await self._update_request_stats(request, status_code)
            await self._log_request(request, status_code, start_time, process_time, error)
        except Exception:
            logger.exception("API monitor failed to record request")

    async def _update_request_stats(self, request: Request, status_code: int):
        """Count one request for its ``"client_ip:method:path"`` key.

        Keeps only the timestamps inside the current window plus the new one
        and stores their number in ``request_counts``. ``status_code`` is
        accepted for interface compatibility but not used.
        """
        client_ip = request.client.host if request.client else "unknown"
        endpoint_key = f"{client_ip}:{request.method}:{request.url.path}"

        now = time.time()
        window_start = now - self.rate_limit_window
        recent = [ts for ts in self.request_timestamps[endpoint_key] if ts > window_start]
        recent.append(now)
        self.request_timestamps[endpoint_key] = recent
        self.request_counts[endpoint_key] = len(recent)

    async def _log_request(
        self,
        request: Request,
        status_code: int,
        start_time: float,
        process_time: float = 0,
        error: Optional[str] = None,
    ):
        """Write one log line for a finished request.

        Level by status: >= 400 -> error, 3xx -> info, anything lower
        (2xx and 1xx) -> debug.
        """
        client_ip = request.client.host if request.client else "unknown"
        user_agent = request.headers.get("user-agent", "")
        referer = request.headers.get("referer", "")

        # Content-Length comes from the client and may be malformed; a bad
        # value must not turn the request into a 500.
        try:
            content_length = int(request.headers.get("content-length", 0))
        except ValueError:
            content_length = 0

        # Structured view of the request; only some fields go into the text
        # line below.
        log_data = {
            "timestamp": datetime.fromtimestamp(start_time).isoformat(),
            "client_ip": client_ip,
            "method": request.method,
            "url": str(request.url),
            "path": request.url.path,
            "query_params": dict(request.query_params),
            "status_code": status_code,
            "process_time_ms": round(process_time * 1000, 2),
            "user_agent": user_agent[:200],  # cap length
            "referer": referer[:200],
            "content_type": request.headers.get("content-type", ""),
            "content_length": content_length,
            "error": error,
        }

        # Join into a single log line.
        log_entry = f"耗时:{log_data['process_time_ms']}ms,错误:{error}, 请求参数: {str(log_data['query_params'])}, 请求url: {log_data['url']} "

        # Both client (4xx) and server (5xx) errors are logged at error level.
        if status_code >= 400:
            logger.error(log_entry)
        elif status_code >= 300:
            logger.info(log_entry)
        else:
            logger.debug(log_entry)

    def get_request_stats(
        self, ip: Optional[str] = None, path: Optional[str] = None
    ) -> Dict[str, Any]:
        """Return per-key statistics, optionally filtered.

        Args:
            ip: If truthy, keep only keys whose client IP equals it exactly.
            path: If truthy, keep only keys whose request path equals it
                exactly.

        Returns:
            Mapping from ``"client_ip:method:path"`` key to a dict of
            counters, last-request time and the key's components.
        """
        stats = {}
        now = time.time()
        window_start = now - self.rate_limit_window

        for key in list(self.request_counts):
            parts = self._parse_key(key)
            if parts is None:
                continue
            client_ip, method, endpoint_path = parts

            if ip and client_ip != ip:
                continue
            if path and endpoint_path != path:
                continue

            timestamps = self.request_timestamps.get(key, [])
            recent_requests = [ts for ts in timestamps if ts > window_start]

            stats[key] = {
                # Timestamps are pruned to the window on every request and by
                # the cleanup task, so this is not a lifetime total.
                "total_requests": len(timestamps),
                "current_window_requests": len(recent_requests),
                "requests_per_minute": len(recent_requests)
                * (60 / self.rate_limit_window),
                "last_request_time": datetime.fromtimestamp(max(timestamps)).isoformat()
                if timestamps
                else None,
                "method": method,
                "endpoint": endpoint_path,
                "client_ip": client_ip,
            }

        return stats

    def get_summary_stats(self) -> Dict[str, Any]:
        """Return aggregate counters over all tracked keys."""
        total_requests = sum(
            len(timestamps) for timestamps in self.request_timestamps.values()
        )

        clients = set()
        endpoints = set()
        for key in list(self.request_timestamps):
            parts = self._parse_key(key)
            if parts is None:
                continue
            clients.add(parts[0])
            endpoints.add(parts[2])

        return {
            "total_requests": total_requests,
            "unique_clients": len(clients),
            "unique_endpoints": len(endpoints),
            "monitoring_since": self.started_at.isoformat(),
            "time_window_seconds": self.rate_limit_window,
        }
|
|
|
|
|
|
class Monitor(Plugin):
    """Plugin that installs ``APIMonitorMiddleware`` plus a stats endpoint.

    Nothing is registered until :meth:`install` is called. It adds the
    middleware to the app and exposes ``GET /api/stats``.
    """

    def __init__(
        self,
        app: FastAPI,
        exclude_paths: Optional[List[str]] = None,
        rate_limit_window: int = 60,
    ):
        """Store the plugin configuration.

        Args:
            app: The FastAPI application to monitor.
            exclude_paths: Path prefixes the middleware should neither count
                nor log. ``None`` means no exclusions.
            rate_limit_window: Sliding-window length in seconds, passed
                through to the middleware.
        """
        logger.info("[监控插件] 加载监控插件 🔧")
        self.app = app
        self.exclude_paths = exclude_paths
        self.rate_limit_window = rate_limit_window
        # Filled in by APIMonitorMiddleware (it receives ``entity=self``) once
        # Starlette instantiates the middleware.
        self.middleware = None
        self.name = "Monitor"
        self.version = "1.0.0"

    async def get_stats(self, ip: Optional[str] = None, path: Optional[str] = None):
        """Return summary and per-key request statistics.

        Registered as ``GET /api/stats``. ``ip`` and ``path`` filter the
        detailed statistics by exact client IP / request path.
        """
        if self.middleware is None:
            return {"error": "Monitor middleware not initialized"}
        return {
            "summary": self.middleware.get_summary_stats(),
            "detailed_stats": self.middleware.get_request_stats(ip, path),
        }

    def install(self):
        """Register the monitoring middleware and the ``/api/stats`` route."""
        self.app.add_middleware(
            APIMonitorMiddleware,
            entity=self,
            exclude_paths=self.exclude_paths,
            rate_limit_window=self.rate_limit_window,
        )

        # NOTE(review): the route is added without any auth dependency and its
        # response includes client IPs; confirm this exposure is intended.
        self.app.add_api_route(
            path="/api/stats",
            endpoint=self.get_stats,
            methods=["GET"],
            summary="获取API统计信息",
            description="获取API统计信息",
        )

        logger.info("[监控插件] 监控插件已加载 ✅")
|