import datetime
import secrets
import uuid
from itertools import islice
from typing import Generic, TypeVar

import jwt
from fastapi import HTTPException, Request
from pydantic import BaseModel

# typing_extensions.TypedDict supports Generic TypedDicts on Python < 3.11 and
# is the variant pydantic requires on Python < 3.12.
from typing_extensions import List, TypedDict

# Identifier type accepted for users: either a string or an integer.
ID_TYPE = str | int


class AuthUser(BaseModel):
    """Base user model used by ``UserManager``.

    When the constructor is called without an ``id`` keyword argument, a
    random UUID4 string is generated for it.
    """

    id: ID_TYPE

    def __init__(self, *args, **kwargs):
        # Only inject a default when the caller omitted ``id`` entirely; an
        # explicitly passed value (even None) is left for pydantic to validate.
        if "id" not in kwargs:
            kwargs["id"] = str(uuid.uuid4())
        super().__init__(*args, **kwargs)


T = TypeVar("T")


class Paginate(TypedDict, Generic[T]):
    """Result dictionary returned by the pagination helpers."""

    items: list[T]  # entries on the returned page
    total: int  # total number of entries in the store
    pages: int  # total page count (0 when the store is empty or per_page <= 0)
    current_page: int  # page number actually returned, after clamping
    per_page: int  # page size as requested by the caller


class UserSource(Generic[T]):
    """In-memory mapping from user id to user object, with pagination."""

    def __init__(self, *args, **kwargs) -> None:
        # Accepts the same arguments as the built-in ``dict`` constructor.
        self._data: dict[ID_TYPE, T] = dict(*args, **kwargs)

    def __getitem__(self, key: ID_TYPE) -> T:
        return self._data[key]

    def __setitem__(self, key: ID_TYPE, value: T) -> None:
        self._data[key] = value

    def __delitem__(self, key: ID_TYPE) -> None:
        del self._data[key]

    def __contains__(self, key: ID_TYPE) -> bool:
        return key in self._data

    def keys(self) -> List[ID_TYPE]:
        """Return a snapshot list of all stored ids."""
        return list(self._data.keys())

    def get(self, key: ID_TYPE, default: T | None = None) -> T | None:
        """Return the value stored under ``key``, or ``default`` if absent."""
        return self._data.get(key, default)

    def paginate(self, page: int = 1, per_page: int = 10) -> Paginate[T]:
        """Return one page of stored values, in insertion order.

        Args:
            page: 1-based page number. Values below 1 are treated as 1, and
                values past the last page are clamped to the last page (when
                at least one page exists).
            per_page: Number of entries per page. A value <= 0 yields no
                items and ``pages == 0``.

        Returns:
            A dict with keys:
                - items: the entries on the returned page
                - total: total number of entries
                - pages: total number of pages
                - current_page: the page number actually returned
                - per_page: the requested page size
        """
        total = len(self._data)
        # Ceiling division; no pages can be formed from a non-positive size.
        pages = (total + per_page - 1) // per_page if per_page > 0 else 0

        page = max(page, 1)
        if 0 < pages < page:
            page = pages

        if per_page > 0:
            start = (page - 1) * per_page
            # islice avoids copying every value just to take one page.
            items = list(islice(self._data.values(), start, start + per_page))
        else:
            # A negative per_page used to produce a nonsensical slice such as
            # [0:-5]; with no valid page size there is nothing to return.
            items = []

        return {
            "items": items,
            "total": total,
            "pages": pages,
            "current_page": page,
            "per_page": per_page,
        }

    def __repr__(self) -> str:
        return f"UserSource({self._data})"


TUser = TypeVar("TUser", bound=AuthUser)


class UserManager(Generic[TUser]):
    """User store plus JWT issuing/verification and a FastAPI dependency.

    Configuration is done through the class attributes below; subclasses can
    override any of them.
    """

    # NOTE(review): this store is created once, when the class body runs, so
    # it is shared by every UserManager instance and by every subclass that
    # does not assign its own ``users``. Override it if isolation is needed.
    users: UserSource[TUser] = UserSource()
    # NOTE(review): generated at import time, so unless overridden, tokens
    # signed by one process fail verification after a restart or in another
    # process. Configure a stable secret for real deployments.
    secret_key: str = secrets.token_urlsafe(32)
    expiration_time: int = 1440  # token lifetime, in minutes
    algorithm: str = "HS256"  # signing algorithm passed to jwt.encode/decode
    token_type: str = "Bearer"  # auth scheme expected in the header
    auth_header_field: str = "Authorization"  # header carrying the token

    def get_user(self, user_id: ID_TYPE) -> TUser | None:
        """Return the user stored under ``user_id``, or None if absent."""
        return self.users.get(user_id, None)

    def add_user(self, user: TUser) -> bool:
        """Store ``user``; return False without changes if its id exists."""
        # Direct membership is an O(1) dict lookup, unlike scanning keys().
        if user.id in self.users:
            return False
        self.users[user.id] = user
        return True

    def delete_user(self, user_id: ID_TYPE) -> bool:
        """Remove the user with ``user_id``; return False if it was absent."""
        if user_id not in self.users:
            return False
        del self.users[user_id]
        return True

    def update_user(self, user: TUser) -> bool:
        """Replace an existing user; return False if its id is unknown."""
        if user.id not in self.users:
            return False
        self.users[user.id] = user
        return True

    def create_token(self, user: TUser) -> str:
        """Issue a signed JWT for ``user``.

        The payload holds every field of the user model plus an ``exp`` claim
        set ``expiration_time`` minutes after the current UTC time.
        """
        # NOTE(review): the whole model is embedded in the token, which is
        # signed but not encrypted; keep secrets (e.g. password hashes) out
        # of the user model or trim this payload.
        payload = {
            # mode="json" turns datetimes, UUIDs, etc. into JSON-native values;
            # plain model_dump() output can make jwt.encode fail on them.
            **user.model_dump(mode="json"),
            "exp": datetime.datetime.now(datetime.timezone.utc)
            + datetime.timedelta(minutes=self.expiration_time),
        }
        return jwt.encode(payload, self.secret_key, algorithm=self.algorithm)

    def paginate_users(self, page: int = 1, per_page: int = 10) -> Paginate[TUser]:
        """Return one page of users.

        Args:
            page: 1-based page number.
            per_page: Number of users per page.

        Returns:
            The pagination dict produced by ``UserSource.paginate``.
        """
        return self.users.paginate(page=page, per_page=per_page)

    def verify_token(self, token: str) -> TUser | None:
        """Decode ``token`` and return the user it identifies.

        Returns:
            The stored user, or None when the payload has no ``id`` claim or
            no user with that id is stored.

        Raises:
            jwt.ExpiredSignatureError: the token's ``exp`` has passed.
            jwt.InvalidTokenError: the token is malformed or fails validation.
        """
        # Let jwt.decode's own exceptions propagate so their message and
        # concrete type are preserved (re-raising the bare class lost both).
        payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
        user_id = payload.get("id")
        if user_id is None:
            return None
        return self.get_user(user_id)

    def refresh_token(self, token: str) -> str | None:
        """Issue a fresh token for the user identified by ``token``.

        Returns:
            A new token, or None when the payload has no ``id`` claim or the
            user is no longer stored.

        Raises:
            jwt.InvalidTokenError (including jwt.ExpiredSignatureError): raised
                by jwt.decode when ``token`` does not validate.
        """
        payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
        user_id = payload.get("id")
        if user_id is None:
            return None
        user = self.get_user(user_id)
        if user is None:
            return None
        return self.create_token(user)

    def fastapi_get_user(self, request: Request) -> TUser:
        """FastAPI dependency returning the current user from the request.

        Expects the header ``<auth_header_field>: <token_type> <token>``.

        Raises:
            HTTPException: 401 when the header is missing or malformed, the
                token fails verification, or the user does not exist.
        """
        # RFC 7235 requires a WWW-Authenticate challenge on every 401.
        challenge = {"WWW-Authenticate": self.token_type}

        auth_header = request.headers.get(self.auth_header_field)
        if not auth_header:
            raise HTTPException(status_code=401, detail="Invalid token", headers=challenge)

        parts = auth_header.split()
        # Authentication scheme names are case-insensitive (RFC 7235 §2.1).
        if len(parts) != 2 or parts[0].lower() != self.token_type.lower():
            raise HTTPException(status_code=401, detail="Invalid token", headers=challenge)
        token = parts[1]

        try:
            user = self.verify_token(token)
        except jwt.InvalidTokenError as e:
            # ExpiredSignatureError subclasses InvalidTokenError, so expiry is
            # handled here too; other errors are genuine bugs and propagate.
            raise HTTPException(
                status_code=401, detail=f"Invalid token {e}", headers=challenge
            ) from e

        if user is None:
            raise HTTPException(status_code=401, detail="User not found", headers=challenge)
        return user