import asyncio
import time
from collections import defaultdict
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Tuple

from fastapi import FastAPI, Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from uvicorn.server import logger

from plugin.base import Plugin


class APIMonitorMiddleware(BaseHTTPMiddleware):
    """Middleware that logs every request and keeps per-client request-rate data.

    Requests are bucketed under a ``"<client_ip>:<method>:<path>"`` key. For each
    key the middleware retains the timestamps of recent requests (pruned to
    ``rate_limit_window`` seconds) and exposes them through
    :meth:`get_request_stats` and :meth:`get_summary_stats`.
    """

    def __init__(
        self,
        app: FastAPI,
        *,
        entity: object = None,
        exclude_paths: Optional[List[str]] = None,
        rate_limit_window: int = 60,
    ):
        """Create the middleware.

        Args:
            app: The downstream ASGI application (supplied by Starlette).
            entity: Optional owner object; when given, its ``middleware``
                attribute is set to this instance so the owner can query stats.
            exclude_paths: Path prefixes that bypass monitoring entirely.
            rate_limit_window: Length of the sliding statistics window, in seconds.

        Raises:
            ValueError: If ``rate_limit_window`` is not positive.
        """
        super().__init__(app)
        if rate_limit_window <= 0:
            # A zero window would make the cleanup loop spin and divide by zero
            # in get_request_stats().
            raise ValueError("rate_limit_window must be a positive number of seconds")
        self.exclude_paths = exclude_paths or []
        self.rate_limit_window = rate_limit_window
        # When this middleware instance was created; reported as "monitoring_since".
        self.started_at = datetime.now()

        # Per-key request data: number of requests in the current window, and the
        # retained request timestamps (epoch seconds).
        self.request_counts: Dict[str, int] = defaultdict(int)
        self.request_timestamps: Dict[str, List[float]] = defaultdict(list)

        # Expose this instance to its owner (e.g. the Monitor plugin) for later access.
        if entity is not None:
            entity.middleware = self

        # Background task that prunes stale records. If no event loop is running
        # yet, it is started lazily by dispatch() on the first request.
        self.cleanup_task: Optional[asyncio.Task] = None
        self._start_cleanup_task()

    def _start_cleanup_task(self) -> None:
        """Start the periodic cleanup task once, if an event loop is running."""
        if self.cleanup_task is not None:
            return
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # No running loop (yet); try again from dispatch().
            return
        self.cleanup_task = loop.create_task(self._cleanup_old_records())

    async def _cleanup_old_records(self):
        """Periodically drop request timestamps that fall outside the window.

        Wakes every ``2 * rate_limit_window`` seconds. Keys left with no
        timestamps are removed from both ``request_timestamps`` and
        ``request_counts`` so the dictionaries do not grow without bound.
        Exits quietly when cancelled.
        """
        try:
            while True:
                await asyncio.sleep(self.rate_limit_window * 2)
                now = time.time()
                for key in list(self.request_timestamps.keys()):
                    recent = [
                        ts
                        for ts in self.request_timestamps[key]
                        if now - ts < self.rate_limit_window
                    ]
                    if recent:
                        self.request_timestamps[key] = recent
                    else:
                        del self.request_timestamps[key]
                        self.request_counts.pop(key, None)
        except asyncio.CancelledError:
            # Cancellation is the normal way to stop this task.
            pass

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Forward the request, then record statistics and a log line for it.

        Paths starting with any prefix in ``exclude_paths`` are forwarded without
        monitoring. Exceptions raised downstream are recorded with status 500 and
        then re-raised so FastAPI's own error handling still applies.
        """
        if any(request.url.path.startswith(prefix) for prefix in self.exclude_paths):
            return await call_next(request)

        self._start_cleanup_task()

        start_time = time.time()
        # Defaults cover the failure path (including non-Exception interrupts).
        status_code = 500
        error: Optional[str] = None
        try:
            response = await call_next(request)
            status_code = response.status_code
            return response
        except Exception as exc:
            error = str(exc)
            # Re-raise so FastAPI can handle the exception.
            raise
        finally:
            # Record on every path, success or failure.
            process_time = time.time() - start_time
            await self._record(request, status_code, start_time, process_time, error)

    async def _record(
        self,
        request: Request,
        status_code: int,
        start_time: float,
        process_time: float,
        error: Optional[str],
    ) -> None:
        """Update statistics and log the request without letting failures escape.

        Monitoring must never break the request itself or mask an exception that
        is already propagating, so any error here is logged and swallowed.
        """
        try:
            await self._update_request_stats(request, status_code)
            await self._log_request(request, status_code, start_time, process_time, error)
        except Exception:
            logger.exception("[Monitor] failed to record request statistics")

    async def _update_request_stats(self, request: Request, status_code: int):
        """Append the current time to this request's key and prune it to the window.

        ``status_code`` is accepted for interface compatibility and is currently unused.
        """
        client_ip = request.client.host if request.client else "unknown"
        endpoint_key = f"{client_ip}:{request.method}:{request.url.path}"
        now = time.time()
        window_start = now - self.rate_limit_window

        recent = [ts for ts in self.request_timestamps[endpoint_key] if ts >= window_start]
        recent.append(now)
        self.request_timestamps[endpoint_key] = recent
        self.request_counts[endpoint_key] = len(recent)

    @staticmethod
    def _parse_content_length(raw: Optional[str]) -> int:
        """Return the Content-Length header value as an int, or 0 if absent or malformed."""
        if raw is None:
            return 0
        try:
            return int(raw)
        except ValueError:
            return 0

    async def _log_request(
        self,
        request: Request,
        status_code: int,
        start_time: float,
        process_time: float = 0,
        error: Optional[str] = None,
    ):
        """Write one log line describing a finished request.

        Log level by status: >=500 error, >=400 warning, >=300 info, otherwise debug.
        """
        client_ip = request.client.host if request.client else "unknown"
        user_agent = request.headers.get("user-agent", "")
        referer = request.headers.get("referer", "")

        log_data = {
            "timestamp": datetime.fromtimestamp(start_time).isoformat(),
            "client_ip": client_ip,
            "method": request.method,
            "url": str(request.url),
            "path": request.url.path,
            "query_params": dict(request.query_params),
            "status_code": status_code,
            "process_time_ms": round(process_time * 1000, 2),
            "user_agent": user_agent[:200],  # truncate to bound log size
            "referer": referer[:200],
            "content_type": request.headers.get("content-type", ""),
            # A malformed header must not raise here (this runs in dispatch's finally).
            "content_length": self._parse_content_length(
                request.headers.get("content-length")
            ),
            "error": error,
        }

        # Flatten into a single log line.
        log_entry = f"耗时:{log_data['process_time_ms']}ms,错误:{error}, 请求参数: {str(log_data['query_params'])}, 请求url: {log_data['url']} "

        if status_code >= 500:
            logger.error(log_entry)
        elif status_code >= 400:
            logger.warning(log_entry)
        elif status_code >= 300:
            logger.info(log_entry)
        else:
            logger.debug(log_entry)

    @staticmethod
    def _split_key(key: str) -> Optional[Tuple[str, str, str]]:
        """Split a ``"<client_ip>:<method>:<path>"`` key into its three parts.

        IPv6 client addresses and request paths may both contain ``:``, so a
        plain ``split(":")`` is ambiguous. Request paths normally start with
        ``/``, while IPs and HTTP methods never contain one, so the first
        ``":/"`` marks the method/path boundary. Keys without it fall back to
        ``split(":", 2)``. Returns None if the key cannot be split into three parts.
        """
        boundary = key.find(":/")
        if boundary == -1:
            parts = key.split(":", 2)
            if len(parts) != 3:
                return None
            return parts[0], parts[1], parts[2]
        client_ip, sep, method = key[:boundary].rpartition(":")
        if not sep:
            return None
        return client_ip, method, key[boundary + 1:]

    def get_request_stats(
        self, ip: Optional[str] = None, path: Optional[str] = None
    ) -> Dict[str, Any]:
        """Return per-key request statistics, optionally filtered by client IP and path.

        Note: "total_requests" counts only the timestamps currently retained,
        which are pruned to the window on each request and by the cleanup task.
        It is not a lifetime total.
        """
        stats: Dict[str, Any] = {}
        now = time.time()
        for key in list(self.request_counts.keys()):
            parts = self._split_key(key)
            if parts is None:
                continue
            client_ip, method, endpoint_path = parts
            if ip and client_ip != ip:
                continue
            if path and endpoint_path != path:
                continue

            timestamps = self.request_timestamps.get(key, [])
            recent = [ts for ts in timestamps if now - ts < self.rate_limit_window]
            stats[key] = {
                "total_requests": len(timestamps),
                "current_window_requests": len(recent),
                "requests_per_minute": len(recent) * (60 / self.rate_limit_window),
                "last_request_time": datetime.fromtimestamp(max(timestamps)).isoformat()
                if timestamps
                else None,
                "method": method,
                "endpoint": endpoint_path,
                "client_ip": client_ip,
            }
        return stats

    def get_summary_stats(self) -> Dict[str, Any]:
        """Return aggregate statistics across all tracked keys."""
        parsed = [
            parts
            for parts in (self._split_key(k) for k in list(self.request_timestamps.keys()))
            if parts is not None
        ]
        total_requests = sum(len(ts) for ts in self.request_timestamps.values())
        return {
            "total_requests": total_requests,
            "unique_clients": len({parts[0] for parts in parsed}),
            "unique_endpoints": len({parts[2] for parts in parsed}),
            "monitoring_since": self.started_at.isoformat(),
            "time_window_seconds": self.rate_limit_window,
        }


class Monitor(Plugin):
    """Plugin that installs APIMonitorMiddleware and a GET /api/stats endpoint."""

    def __init__(
        self, app: FastAPI, exclude_paths: Optional[list] = None, rate_limit_window: int = 60
    ):
        """Configure the monitor plugin.

        Args:
            app (FastAPI): The application to instrument.
            exclude_paths (list, optional): Path prefixes the middleware will not
                monitor. Defaults to None (monitor everything).
            rate_limit_window (int, optional): Statistics window length in seconds.
                Defaults to 60.
        """
        logger.info("[监控插件] 加载监控插件 🔧")
        self.app = app
        self.exclude_paths = exclude_paths
        self.rate_limit_window = rate_limit_window
        # Assigned by APIMonitorMiddleware.__init__ (via entity=self) when the
        # middleware is instantiated.
        self.middleware = None
        self.name = "Monitor"
        self.version = "1.0.0"

    async def get_stats(self, ip: str = None, path: str = None):
        """Return summary and detailed API statistics, or an error if not yet initialized."""
        if self.middleware:
            return {
                "summary": self.middleware.get_summary_stats(),
                "detailed_stats": self.middleware.get_request_stats(ip, path),
            }
        return {"error": "Monitor middleware not initialized"}

    def install(self):
        """Register the monitoring middleware and the /api/stats route on the app."""
        self.app.add_middleware(
            APIMonitorMiddleware,
            entity=self,
            exclude_paths=self.exclude_paths,
            rate_limit_window=self.rate_limit_window,
        )
        self.app.add_api_route(
            path="/api/stats",
            endpoint=self.get_stats,
            methods=["GET"],
            summary="获取API统计信息",
            description="获取API统计信息",
        )
        logger.info("[监控插件] 监控插件已加载 ✅")