Initial commit
This commit is contained in:
190
plugin/user/user_manager.py
Normal file
190
plugin/user/user_manager.py
Normal file
@@ -0,0 +1,190 @@
|
||||
import datetime
|
||||
import secrets
|
||||
import uuid
|
||||
from typing import Generic, TypedDict, TypeVar
|
||||
|
||||
import jwt
|
||||
from fastapi import HTTPException, Request
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import List
|
||||
|
||||
ID_TYPE = str | int
|
||||
|
||||
|
||||
class AuthUser(BaseModel):
|
||||
id: ID_TYPE
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
if "id" not in kwargs:
|
||||
kwargs["id"] = str(uuid.uuid4())
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class Paginate(TypedDict, Generic[T]):
|
||||
items: list[T]
|
||||
total: int
|
||||
pages: int
|
||||
current_page: int
|
||||
per_page: int
|
||||
|
||||
|
||||
class UserSource(Generic[T]):
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
self._data: dict[ID_TYPE, T] = dict[ID_TYPE, T](*args, **kwargs)
|
||||
|
||||
def __getitem__(self, key: ID_TYPE) -> T:
|
||||
return self._data[key]
|
||||
|
||||
def __setitem__(self, key: ID_TYPE, value: T) -> None:
|
||||
self._data[key] = value
|
||||
|
||||
def __delitem__(self, key: ID_TYPE) -> None:
|
||||
del self._data[key]
|
||||
|
||||
def __contains__(self, key: ID_TYPE) -> bool:
|
||||
return key in self._data
|
||||
|
||||
def keys(self) -> List[ID_TYPE]:
|
||||
return list(self._data.keys())
|
||||
|
||||
def get(self, key: ID_TYPE, default: T | None = None) -> T | None:
|
||||
return self._data.get(key, default)
|
||||
|
||||
def paginate(self, page: int = 1, per_page: int = 10) -> Paginate[T]:
|
||||
"""分页获取用户数据
|
||||
|
||||
Args:
|
||||
page: 页码,从1开始
|
||||
per_page: 每页数量
|
||||
|
||||
Returns:
|
||||
包含分页信息的字典,包括:
|
||||
- items: 当前页的数据列表
|
||||
- total: 总数据量
|
||||
- pages: 总页数
|
||||
- current_page: 当前页码
|
||||
- per_page: 每页数量
|
||||
"""
|
||||
total = len(self._data)
|
||||
pages = (total + per_page - 1) // per_page if per_page > 0 else 0
|
||||
|
||||
if page < 1:
|
||||
page = 1
|
||||
if page > pages and pages > 0:
|
||||
page = pages
|
||||
|
||||
start = (page - 1) * per_page
|
||||
end = start + per_page
|
||||
|
||||
items = list(self._data.values())[start:end]
|
||||
|
||||
return {
|
||||
"items": items,
|
||||
"total": total,
|
||||
"pages": pages,
|
||||
"current_page": page,
|
||||
"per_page": per_page,
|
||||
}
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"UserSource({self._data})"
|
||||
|
||||
|
||||
TUser = TypeVar("TUser", bound=AuthUser)
|
||||
|
||||
|
||||
class UserManager(Generic[TUser]):
|
||||
users: UserSource[TUser] = UserSource()
|
||||
secret_key: str = secrets.token_urlsafe(32)
|
||||
expiration_time: int = 1440
|
||||
algorithm: str = "HS256"
|
||||
token_type: str = "Bearer"
|
||||
auth_header_field: str = "Authorization"
|
||||
|
||||
def get_user(self, user_id: ID_TYPE) -> TUser | None:
|
||||
return self.users.get(user_id, None)
|
||||
|
||||
def add_user(self, user: TUser) -> bool:
|
||||
if user.id in self.users.keys():
|
||||
return False
|
||||
self.users[user.id] = user
|
||||
return True
|
||||
|
||||
def delete_user(self, user_id: ID_TYPE) -> bool:
|
||||
if user_id not in self.users:
|
||||
return False
|
||||
del self.users[user_id]
|
||||
return True
|
||||
|
||||
def update_user(self, user: TUser) -> bool:
|
||||
if user.id not in self.users:
|
||||
return False
|
||||
self.users[user.id] = user
|
||||
return True
|
||||
|
||||
def create_token(self, user: TUser) -> str:
|
||||
payload = {
|
||||
**user.model_dump(),
|
||||
"exp": datetime.datetime.now(datetime.timezone.utc)
|
||||
+ datetime.timedelta(minutes=self.expiration_time),
|
||||
}
|
||||
token = jwt.encode(payload, self.secret_key, algorithm=self.algorithm)
|
||||
return token
|
||||
|
||||
def paginate_users(self, page: int = 1, per_page: int = 10) -> Paginate[TUser]:
|
||||
"""分页查询用户
|
||||
|
||||
Args:
|
||||
page: 页码,从1开始
|
||||
per_page: 每页数量
|
||||
|
||||
Returns:
|
||||
包含分页信息的字典
|
||||
"""
|
||||
return self.users.paginate(page=page, per_page=per_page)
|
||||
|
||||
def verify_token(self, token: str) -> TUser | None:
|
||||
import jwt
|
||||
|
||||
try:
|
||||
payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
|
||||
user_id = payload.get("id")
|
||||
if user_id is None:
|
||||
return None
|
||||
return self.get_user(user_id)
|
||||
except jwt.ExpiredSignatureError:
|
||||
raise jwt.ExpiredSignatureError
|
||||
except jwt.InvalidTokenError:
|
||||
raise jwt.InvalidTokenError
|
||||
|
||||
def refresh_token(self, token: str) -> str | None:
|
||||
payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
|
||||
user_id = payload.get("id")
|
||||
if user_id is None:
|
||||
return None
|
||||
user = self.get_user(user_id)
|
||||
if user is None:
|
||||
return None
|
||||
return self.create_token(user)
|
||||
|
||||
def fastapi_get_user(self, request: Request) -> TUser | None:
|
||||
"""FastAPI dependency to get current user from token"""
|
||||
auth_header = request.headers.get(self.auth_header_field)
|
||||
if not auth_header:
|
||||
raise HTTPException(status_code=401, detail="Invalid token")
|
||||
|
||||
parts = auth_header.split()
|
||||
if len(parts) != 2 or parts[0] != self.token_type:
|
||||
raise HTTPException(status_code=401, detail="Invalid token")
|
||||
|
||||
token = parts[1]
|
||||
try:
|
||||
user = self.verify_token(token)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=401, detail=f"Invalid token {e}")
|
||||
if not user:
|
||||
raise HTTPException(status_code=401, detail="User not found")
|
||||
return user
|
||||
Reference in New Issue
Block a user