62 lines
2.1 KiB
Python
62 lines
2.1 KiB
Python
from fastapi import HTTPException, status, Request, Depends
|
|
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
|
from typing import Optional
|
|
from app.services.jwt_service import jwt_service
|
|
from app.core.database import get_db_connection
|
|
|
|
# Bearer-token security scheme; get_current_user receives its credentials
# through Depends(security).
security = HTTPBearer()
|
|
|
|
def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Resolve the authenticated user for the current request.

    Meant to be used as a FastAPI dependency: the bearer token supplied by
    ``security`` is verified with ``jwt_service`` and the ``user_id`` claim is
    then looked up in the ``users`` table, so a token whose user row no longer
    exists is rejected.

    Args:
        credentials: Bearer credentials provided by the ``security`` scheme.

    Returns:
        The row fetched for the user (``user_id``, ``username``, ``caption``
        and ``email`` columns) from a ``dictionary=True`` cursor.

    Raises:
        HTTPException: 401 when token verification fails, or when no user row
            matches the token's ``user_id``.
    """
    # Verify the JWT before touching the database.
    payload = jwt_service.verify_token(credentials.credentials)
    if not payload:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or expired token",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # Confirm against the database that the user still exists.
    user = None
    user_id = payload.get("user_id")
    # A missing claim can never match a row (``user_id = NULL`` is never true
    # in SQL), so skip the round trip and fall through to the 401 below.
    if user_id is not None:
        with get_db_connection() as connection:
            cursor = connection.cursor(dictionary=True)
            try:
                cursor.execute(
                    "SELECT user_id, username, caption, email FROM users WHERE user_id = %s",
                    (user_id,),
                )
                user = cursor.fetchone()
            finally:
                # Release the cursor explicitly instead of relying on
                # connection teardown.
                cursor.close()

    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="User not found",
            headers={"WWW-Authenticate": "Bearer"},
        )

    return user
|
|
|
|
def get_optional_current_user(request: Request) -> Optional[dict]:
    """Best-effort user lookup for endpoints that do not require login.

    Reads the ``Authorization`` header directly from the request. Any problem
    along the way (missing or non-bearer header, empty token, failed token
    verification, missing ``user_id`` claim, lookup error) yields ``None``
    instead of an HTTP error.

    Args:
        request: The incoming request whose headers are inspected.

    Returns:
        The row fetched for the user (``user_id``, ``username``, ``caption``
        and ``email`` columns), or ``None`` when the caller cannot be
        identified.
    """
    auth_header = request.headers.get("Authorization")
    if not auth_header:
        return None

    # The auth scheme is compared case-insensitively, and any extra
    # whitespace between the scheme and the token is tolerated.
    scheme, _, token = auth_header.partition(" ")
    token = token.strip()
    if scheme.lower() != "bearer" or not token:
        return None

    try:
        payload = jwt_service.verify_token(token)
        if not payload:
            return None

        user_id = payload.get("user_id")
        if user_id is None:
            # No claim to look up; ``user_id = NULL`` would match nothing.
            return None

        with get_db_connection() as connection:
            cursor = connection.cursor(dictionary=True)
            try:
                cursor.execute(
                    "SELECT user_id, username, caption, email FROM users WHERE user_id = %s",
                    (user_id,),
                )
                return cursor.fetchone()
            finally:
                # Release the cursor explicitly instead of relying on
                # connection teardown.
                cursor.close()
    except Exception:
        # Authentication is optional here, so failures degrade to "anonymous".
        # Catching Exception (not a bare except) lets KeyboardInterrupt and
        # SystemExit propagate.
        return None