# Listing metadata: 191 lines, 5.4 KiB, Python
import datetime
import itertools
import secrets
import uuid
from typing import Generic, TypedDict, TypeVar

import jwt
from fastapi import HTTPException, Request
from pydantic import BaseModel, Field
from typing_extensions import List


# Type accepted for user identifiers: a string (AuthUser generates a UUID4
# string when no id is given) or an integer.
ID_TYPE = str | int


class AuthUser(BaseModel):
    """Base pydantic model for users handled by ``UserManager``.

    Subclass it to add application-specific fields. ``id`` may be omitted at
    construction time; a new random UUID4 string is generated in that case.
    """

    # default_factory (rather than a custom ``__init__`` that injects ``id``)
    # keeps pydantic's own constructor, makes the generated JSON schema /
    # OpenAPI docs mark ``id`` as optional (matching runtime behaviour), and
    # also applies when instances are built via ``model_construct``.
    id: ID_TYPE = Field(default_factory=lambda: str(uuid.uuid4()))


# Generic element type: the item type of Paginate and the value type of UserSource.
T = TypeVar("T")


class Paginate(TypedDict, Generic[T]):
    """One page of results plus paging metadata (returned by UserSource.paginate)."""

    # NOTE(review): generic TypedDicts built from ``typing.TypedDict`` require
    # Python 3.11+; on 3.10 this class statement raises TypeError. Confirm the
    # minimum supported Python version, or use typing_extensions.TypedDict.

    # Items on the requested page.
    items: list[T]
    # Total number of stored items across all pages.
    total: int
    # Number of pages needed to hold ``total`` items at ``per_page`` per page.
    pages: int
    # The page number that was actually served.
    current_page: int
    # The requested page size.
    per_page: int


class UserSource(Generic[T]):
    """Dict-backed in-memory store mapping user ids to user objects.

    Values are kept in the underlying dict's insertion order, which is also the
    order in which ``paginate`` serves them.
    """

    def __init__(self, *args, **kwargs) -> None:
        """Initialise the store; accepts the same arguments as ``dict()``."""
        self._data: dict[ID_TYPE, T] = dict(*args, **kwargs)

    def __getitem__(self, key: ID_TYPE) -> T:
        """Return the value stored under *key*; raise ``KeyError`` if missing."""
        return self._data[key]

    def __setitem__(self, key: ID_TYPE, value: T) -> None:
        """Store *value* under *key*, replacing any existing entry."""
        self._data[key] = value

    def __delitem__(self, key: ID_TYPE) -> None:
        """Remove the entry for *key*; raise ``KeyError`` if missing."""
        del self._data[key]

    def __contains__(self, key: ID_TYPE) -> bool:
        """Return whether *key* is stored (a dict membership lookup)."""
        return key in self._data

    def keys(self) -> List[ID_TYPE]:
        """Return a new list of all stored keys, in insertion order."""
        return list(self._data.keys())

    def get(self, key: ID_TYPE, default: T | None = None) -> T | None:
        """Return the value for *key*, or *default* when it is not stored."""
        return self._data.get(key, default)

    def paginate(self, page: int = 1, per_page: int = 10) -> Paginate[T]:
        """Return one page of stored values together with paging metadata.

        Args:
            page: Page number, starting at 1. Values below 1 are treated as 1;
                values past the last page are clamped to the last page when at
                least one page exists.
            per_page: Number of items per page. A value <= 0 yields
                ``pages == 0`` and an empty ``items`` list.

        Returns:
            A dict containing:
            - items: the values on the served page
            - total: total number of stored values
            - pages: total number of pages
            - current_page: the page number actually served (after clamping)
            - per_page: the requested page size
        """
        total = len(self._data)
        pages = (total + per_page - 1) // per_page if per_page > 0 else 0

        # Clamp the requested page into [1, pages]; with zero pages only the
        # lower bound applies.
        if page < 1:
            page = 1
        if page > pages and pages > 0:
            page = pages

        if per_page > 0:
            start = (page - 1) * per_page
            # islice walks only up to the end of this page instead of copying
            # every stored value into a list first.
            items = list(itertools.islice(self._data.values(), start, start + per_page))
        else:
            # A non-positive page size has no pages to serve. Slicing with a
            # negative per_page would yield arbitrary chunks of the data.
            items = []

        return {
            "items": items,
            "total": total,
            "pages": pages,
            "current_page": page,
            "per_page": per_page,
        }

    def __repr__(self) -> str:
        return f"UserSource({self._data})"


# User model type managed by UserManager; must be AuthUser or a subclass of it.
TUser = TypeVar("TUser", bound=AuthUser)


class UserManager(Generic[TUser]):
    """In-memory user registry with JWT issuing/verification and a FastAPI dependency.

    All settings are class attributes, so they can be overridden by subclassing
    or by assigning on an instance.
    """

    # NOTE(review): ``users`` is a single UserSource created when the class is
    # defined, so it is shared by every UserManager instance and by every
    # subclass that does not assign its own ``users``. Assign
    # ``users = UserSource()`` in a subclass body if isolation is required.
    users: UserSource[TUser] = UserSource()
    # Signing key generated once per process when this module is imported;
    # tokens signed by a previous process will fail verification unless this
    # is overridden with a persistent secret.
    secret_key: str = secrets.token_urlsafe(32)
    # Token lifetime in minutes (used as ``timedelta(minutes=...)``).
    expiration_time: int = 1440
    # JWT algorithm passed to jwt.encode / jwt.decode.
    algorithm: str = "HS256"
    # Authentication scheme expected before the token in the auth header.
    token_type: str = "Bearer"
    # Request header carrying "<token_type> <token>".
    auth_header_field: str = "Authorization"

    def get_user(self, user_id: ID_TYPE) -> TUser | None:
        """Return the user stored under *user_id*, or ``None`` if unknown."""
        return self.users.get(user_id, None)

    def add_user(self, user: TUser) -> bool:
        """Register *user*; return ``False`` without changes if its id is taken."""
        # ``in`` uses UserSource.__contains__ (a dict lookup) rather than
        # building the full key list that ``keys()`` returns.
        if user.id in self.users:
            return False
        self.users[user.id] = user
        return True

    def delete_user(self, user_id: ID_TYPE) -> bool:
        """Remove the user with *user_id*; return ``False`` if it was not registered."""
        if user_id not in self.users:
            return False
        del self.users[user_id]
        return True

    def update_user(self, user: TUser) -> bool:
        """Replace the stored user with the same id; return ``False`` if not registered."""
        if user.id not in self.users:
            return False
        self.users[user.id] = user
        return True

    def create_token(self, user: TUser) -> str:
        """Return a signed JWT whose claims are the user's fields plus ``exp``.

        NOTE(review): every field of the user model is copied into the token.
        JWTs are signed, not encrypted, so any sensitive field (e.g. a password
        hash) is readable by whoever holds the token; confirm this is intended.
        """
        payload = {
            # mode="json" converts values such as datetime, UUID or set into
            # JSON-compatible types, which jwt.encode can serialize.
            **user.model_dump(mode="json"),
            "exp": datetime.datetime.now(datetime.timezone.utc)
            + datetime.timedelta(minutes=self.expiration_time),
        }
        return jwt.encode(payload, self.secret_key, algorithm=self.algorithm)

    def paginate_users(self, page: int = 1, per_page: int = 10) -> Paginate[TUser]:
        """Return one page of registered users.

        Args:
            page: Page number, starting at 1.
            per_page: Number of users per page.

        Returns:
            A dict with paging information, as produced by ``UserSource.paginate``.
        """
        return self.users.paginate(page=page, per_page=per_page)

    def verify_token(self, token: str) -> TUser | None:
        """Decode *token* and return the registered user it identifies.

        Returns ``None`` when the token has no ``id`` claim or that id is not
        registered.

        Raises:
            jwt.ExpiredSignatureError: the token's ``exp`` claim has passed.
            jwt.InvalidTokenError: any other decoding/validation failure.
        """
        # Errors from jwt.decode propagate unchanged, keeping their specific
        # subclass and message for callers and logs.
        payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
        user_id = payload.get("id")
        if user_id is None:
            return None
        return self.get_user(user_id)

    def refresh_token(self, token: str) -> str | None:
        """Issue a new token for the user identified by a still-valid *token*.

        Returns ``None`` when the token has no ``id`` claim or the user is no
        longer registered. Expired or invalid tokens raise ``jwt`` exceptions
        from ``jwt.decode``.
        """
        payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
        user_id = payload.get("id")
        if user_id is None:
            return None
        user = self.get_user(user_id)
        if user is None:
            return None
        return self.create_token(user)

    def fastapi_get_user(self, request: Request) -> TUser | None:
        """FastAPI dependency to get current user from token.

        Raises:
            HTTPException: 401 when the auth header is missing or malformed,
                the token fails verification, or its user is not registered.
        """
        # RFC 7235 requires a WWW-Authenticate challenge on 401 responses.
        challenge = {"WWW-Authenticate": self.token_type}

        auth_header = request.headers.get(self.auth_header_field)
        if not auth_header:
            raise HTTPException(status_code=401, detail="Invalid token", headers=challenge)

        parts = auth_header.split()
        # Authentication scheme names are case-insensitive (RFC 7235 §2.1),
        # so "bearer <token>" is accepted as well as "Bearer <token>".
        if len(parts) != 2 or parts[0].casefold() != self.token_type.casefold():
            raise HTTPException(status_code=401, detail="Invalid token", headers=challenge)

        token = parts[1]
        try:
            user = self.verify_token(token)
        except Exception as e:
            # Any verification failure is reported to the client as a 401
            # rather than surfacing as a server error.
            raise HTTPException(
                status_code=401, detail=f"Invalid token {e}", headers=challenge
            ) from e
        if user is None:
            raise HTTPException(status_code=401, detail="User not found", headers=challenge)
        return user