Files
intelligent-daily-report-sy…/plugin/user/user_manager.py
2026-02-25 15:22:23 +08:00

191 lines
5.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import datetime
import secrets
import uuid
from typing import Generic, TypedDict, TypeVar
import jwt
from fastapi import HTTPException, Request
from pydantic import BaseModel
from typing_extensions import List
# Type of a user identifier: either a string (AuthUser defaults to a UUID4
# string) or an integer. The runtime `|` union requires Python 3.10+.
ID_TYPE = str | int
class AuthUser(BaseModel):
    """Base pydantic model for authenticated users.

    The only field is ``id``. When the caller does not pass ``id`` as a
    keyword argument, a random UUID4 string is generated for it; an explicitly
    supplied value (even ``None``) is forwarded unchanged to validation.
    """

    id: ID_TYPE

    def __init__(self, *args, **kwargs):
        if "id" in kwargs:
            super().__init__(*args, **kwargs)
        else:
            # Append a freshly generated id after the caller's own keywords.
            super().__init__(*args, **kwargs, id=str(uuid.uuid4()))
# Generic type variable for the values stored in a UserSource and listed in a Paginate page.
T = TypeVar("T")
class Paginate(TypedDict, Generic[T]):
    """Dictionary shape returned by a pagination query.

    NOTE(review): subclassing both TypedDict and Generic is only supported on
    Python 3.11+ (or with ``typing_extensions.TypedDict``) — confirm the target
    interpreter version.
    """

    items: list[T]  # values on the current page
    total: int  # total number of stored items
    pages: int  # total number of pages
    current_page: int  # current page number (1-based)
    per_page: int  # page size that was requested
class UserSource(Generic[T]):
    """Dict-backed container of users keyed by user id.

    Supports ``[]`` get/set/delete, ``in``, ``keys()``, ``get()`` and
    pagination. Constructor arguments are forwarded to ``dict``, so a
    ``UserSource`` can be seeded from a mapping, an iterable of pairs, or
    keyword arguments.
    """

    def __init__(self, *args, **kwargs) -> None:
        self._data: dict[ID_TYPE, T] = dict[ID_TYPE, T](*args, **kwargs)

    def __getitem__(self, key: ID_TYPE) -> T:
        """Return the value stored under ``key`` (raises ``KeyError`` if absent)."""
        return self._data[key]

    def __setitem__(self, key: ID_TYPE, value: T) -> None:
        """Store ``value`` under ``key``, replacing any existing entry."""
        self._data[key] = value

    def __delitem__(self, key: ID_TYPE) -> None:
        """Remove ``key`` (raises ``KeyError`` if absent)."""
        del self._data[key]

    def __contains__(self, key: ID_TYPE) -> bool:
        """Return whether ``key`` is stored (a dict membership test)."""
        return key in self._data

    def keys(self) -> List[ID_TYPE]:
        """Return a new list of all stored ids, in insertion order."""
        return list(self._data.keys())

    def get(self, key: ID_TYPE, default: T | None = None) -> T | None:
        """Return the value stored under ``key``, or ``default`` if absent."""
        return self._data.get(key, default)

    def paginate(self, page: int = 1, per_page: int = 10) -> Paginate[T]:
        """Return one page of stored values, in insertion order.

        Args:
            page: Page number, starting at 1. Values below 1 are treated as 1;
                values past the last page are clamped to the last page when at
                least one page exists.
            per_page: Number of items per page. A value <= 0 yields an empty
                page and ``pages == 0``.

        Returns:
            A dictionary containing:
            - items: the values on the current page
            - total: total number of stored items
            - pages: total number of pages
            - current_page: the (possibly clamped) page number
            - per_page: the requested page size
        """
        total = len(self._data)
        pages = (total + per_page - 1) // per_page if per_page > 0 else 0
        if page < 1:
            page = 1
        if page > pages and pages > 0:
            page = pages
        if per_page > 0:
            start = (page - 1) * per_page
            items = list(self._data.values())[start : start + per_page]
        else:
            # A non-positive page size has no valid slice: without this guard a
            # negative per_page becomes a negative slice end and returns "all
            # but the last |per_page|" items while reporting pages == 0.
            items = []
        return {
            "items": items,
            "total": total,
            "pages": pages,
            "current_page": page,
            "per_page": per_page,
        }

    def __repr__(self) -> str:
        return f"UserSource({self._data})"
# Type variable for the concrete user model a UserManager manages; bound to AuthUser,
# so it must be AuthUser or a subclass of it.
TUser = TypeVar("TUser", bound=AuthUser)
class UserManager(Generic[TUser]):
    """In-memory user registry with JWT issuing/verification and a FastAPI auth dependency.

    Attributes:
        users: Store of users keyed by ``user.id``. NOTE(review): this is a
            class-level attribute, so every ``UserManager`` instance (and every
            subclass that does not override it) shares the same store.
        secret_key: HMAC key used to sign and verify tokens. The default is
            generated randomly when the class is defined, so tokens issued by
            one process will not verify in another process or after a restart
            unless a fixed key is assigned.
        expiration_time: Token lifetime in minutes (default 1440 = 24 hours).
        algorithm: JWT signing algorithm passed to PyJWT.
        token_type: Authentication scheme expected before the token in the header.
        auth_header_field: Name of the request header that carries the token.
    """

    users: UserSource[TUser] = UserSource()
    secret_key: str = secrets.token_urlsafe(32)
    expiration_time: int = 1440
    algorithm: str = "HS256"
    token_type: str = "Bearer"
    auth_header_field: str = "Authorization"

    def get_user(self, user_id: ID_TYPE) -> TUser | None:
        """Return the user registered under ``user_id``, or ``None``."""
        return self.users.get(user_id, None)

    def add_user(self, user: TUser) -> bool:
        """Register ``user``.

        Returns:
            ``True`` if the user was added; ``False`` if a user with the same id
            already exists (the existing entry is left untouched).
        """
        # Use the store's O(1) membership test rather than scanning a copied
        # list of every key.
        if user.id in self.users:
            return False
        self.users[user.id] = user
        return True

    def delete_user(self, user_id: ID_TYPE) -> bool:
        """Remove the user registered under ``user_id``; ``False`` if there is none."""
        if user_id not in self.users:
            return False
        del self.users[user_id]
        return True

    def update_user(self, user: TUser) -> bool:
        """Replace the stored user that has the same id; ``False`` if there is none."""
        if user.id not in self.users:
            return False
        self.users[user.id] = user
        return True

    def create_token(self, user: TUser) -> str:
        """Issue a signed JWT whose payload is the user's fields plus an ``exp`` claim.

        NOTE(review): every field of the user model is embedded in the token,
        and a signed (not encrypted) JWT payload is readable by anyone holding
        the token. Make sure user models never carry secrets such as password
        hashes.
        """
        payload = {
            # mode="json" turns datetime/UUID/Decimal/... values into
            # JSON-compatible types; PyJWT serializes the payload as JSON and
            # would otherwise raise TypeError on them.
            **user.model_dump(mode="json"),
            # Listed last so it overrides any "exp" field on the model.
            "exp": datetime.datetime.now(datetime.timezone.utc)
            + datetime.timedelta(minutes=self.expiration_time),
        }
        return jwt.encode(payload, self.secret_key, algorithm=self.algorithm)

    def paginate_users(self, page: int = 1, per_page: int = 10) -> Paginate[TUser]:
        """Return one page of registered users.

        Args:
            page: Page number, starting at 1.
            per_page: Number of users per page.

        Returns:
            A pagination dictionary, as produced by ``UserSource.paginate``.
        """
        return self.users.paginate(page=page, per_page=per_page)

    def verify_token(self, token: str) -> TUser | None:
        """Decode ``token`` and return the registered user it refers to.

        Returns:
            The user whose id is in the token's ``id`` claim, or ``None`` if the
            claim is missing or no such user is registered.

        Raises:
            jwt.ExpiredSignatureError: If the token's ``exp`` time has passed.
            jwt.InvalidTokenError: If PyJWT rejects the token for any other
                reason (bad signature, malformed token, ...).
        """
        # PyJWT's exceptions propagate unchanged; re-raising a fresh,
        # argument-less exception would discard the original error message.
        payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
        user_id = payload.get("id")
        if user_id is None:
            return None
        return self.get_user(user_id)

    def refresh_token(self, token: str) -> str | None:
        """Issue a new token for the user identified by a still-valid ``token``.

        Returns:
            A freshly signed token, or ``None`` if the token has no ``id`` claim
            or its user is no longer registered.

        Raises:
            jwt.InvalidTokenError: (including ``jwt.ExpiredSignatureError``) if
                ``token`` does not decode, so expired tokens cannot be refreshed.
        """
        payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
        user_id = payload.get("id")
        if user_id is None:
            return None
        user = self.get_user(user_id)
        if user is None:
            return None
        return self.create_token(user)

    def fastapi_get_user(self, request: Request) -> TUser | None:
        """FastAPI dependency to get current user from token.

        Expects the ``auth_header_field`` header to hold ``<token_type> <token>``
        (for example ``Authorization: Bearer <jwt>``).

        Raises:
            HTTPException: 401 if the header is missing or malformed, the token
                fails verification, or the token's user is not registered.
        """
        # A 401 response should name the expected scheme (RFC 7235 / RFC 6750).
        challenge = {"WWW-Authenticate": self.token_type}
        auth_header = request.headers.get(self.auth_header_field)
        if not auth_header:
            raise HTTPException(status_code=401, detail="Invalid token", headers=challenge)
        parts = auth_header.split()
        # Authentication schemes are case-insensitive (RFC 7235 section 2.1),
        # so "bearer <token>" is accepted as well as "Bearer <token>".
        if len(parts) != 2 or parts[0].casefold() != self.token_type.casefold():
            raise HTTPException(status_code=401, detail="Invalid token", headers=challenge)
        token = parts[1]
        try:
            user = self.verify_token(token)
        except jwt.InvalidTokenError as e:
            # Only token-validation failures map to 401; unexpected errors
            # (e.g. from a custom user store) propagate instead of being
            # disguised as authentication failures.
            raise HTTPException(
                status_code=401, detail=f"Invalid token {e}", headers=challenge
            ) from e
        if user is None:
            raise HTTPException(status_code=401, detail="User not found", headers=challenge)
        return user