215 lines
6.8 KiB
Python
215 lines
6.8 KiB
Python
"""
|
||
认证依赖:获取当前登录用户
|
||
"""
|
||
from fastapi import Depends, HTTPException, status, Request
|
||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
from sqlalchemy import select
|
||
from typing import Optional
|
||
from app.core.database import get_db
|
||
from app.core.security import decode_access_token
|
||
from app.core.redis_client import TokenCache
|
||
from app.models.user import User
|
||
import logging
|
||
|
||
logger = logging.getLogger(__name__)

# Bearer-token authentication schemes used by the dependencies below.
# `security` is strict: auto_error=True (also HTTPBearer's default) makes FastAPI
# reject a request with no usable "Authorization: Bearer ..." header before the
# dependency body runs.
# `security_optional` is lenient: it hands the dependency None instead, so the
# caller can fall back to another token source (e.g. a query parameter).
security = HTTPBearer(auto_error=True)
security_optional = HTTPBearer(auto_error=False)
|
||
|
||
|
||
def _unauthorized(detail: str) -> HTTPException:
    """Build a 401 error that carries a ``WWW-Authenticate: Bearer`` challenge.

    RFC 7235 requires every 401 response to include ``WWW-Authenticate``, so all
    authentication failures in this module are built here to stay consistent.

    Args:
        detail: User-facing error message placed in the response body.

    Returns:
        The exception to ``raise``.
    """
    return HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=detail,
        headers={"WWW-Authenticate": "Bearer"},
    )


async def _authenticate_token(request: Request, token: str, db: AsyncSession) -> User:
    """Validate a bearer token and return the enabled user it belongs to.

    Checks, in order:

    1. the token is still registered in Redis;
    2. the JWT decodes;
    3. its ``sub`` claim is an integer user id;
    4. that id matches the id stored in Redis;
    5. the user row exists and has ``status == 1``.

    Args:
        request: Current request; the raw token is stored on ``request.state.token``
            so the logout handler can find it.
        token: Raw bearer token string.
        db: Async database session used to load the user.

    Returns:
        The authenticated ``User``.

    Raises:
        HTTPException: 401 for any token or user-lookup failure; 403 if the user
            is disabled.
    """
    # Keep the raw token on the request so the logout endpoint can revoke it.
    request.state.token = token

    # The token must still be present in Redis (it is missing once expired/revoked).
    user_id_from_redis = await TokenCache.get_user_id(token)
    if user_id_from_redis is None:
        logger.warning("Token not found in Redis or expired")
        raise _unauthorized("登录已过期,请重新登录")

    # Decode the JWT to verify its integrity; decode_access_token returns None on failure.
    # Deliberately do not log the payload: it contains the token's claims.
    payload = decode_access_token(token)
    if payload is None:
        logger.warning("Token decode failed: payload is None")
        raise _unauthorized("无效的认证凭证")

    user_id_str = payload.get("sub")
    if user_id_str is None:
        logger.warning("user_id is None in payload")
        raise _unauthorized("无效的认证凭证")

    # The subject claim is stored as a string; convert it to the integer primary key.
    try:
        user_id = int(user_id_str)
    except (ValueError, TypeError):
        logger.warning("Invalid user_id format: %s", user_id_str)
        raise _unauthorized("无效的认证凭证") from None

    # Redis and the JWT must agree on whom the token belongs to.
    # NOTE(review): assumes TokenCache.get_user_id returns an int — confirm.
    if user_id != user_id_from_redis:
        logger.warning("User ID mismatch: JWT=%s, Redis=%s", user_id, user_id_from_redis)
        raise _unauthorized("无效的认证凭证")

    result = await db.execute(select(User).where(User.id == user_id))
    user = result.scalar_one_or_none()
    if user is None:
        logger.warning("User not found for user_id: %s", user_id)
        raise _unauthorized("用户不存在")

    # NOTE(review): status == 1 appears to mean "enabled"; any other value is rejected.
    if user.status != 1:
        logger.warning("User %s is disabled", user_id)
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="用户已被禁用",
        )

    logger.debug("User authenticated successfully: %s", user.username)
    return user


async def get_current_user(
    request: Request,
    credentials: HTTPAuthorizationCredentials = Depends(security),
    db: AsyncSession = Depends(get_db),
) -> User:
    """Dependency: return the user identified by the ``Authorization: Bearer`` header.

    A missing or malformed header is rejected by the strict ``security`` scheme
    before this body runs.

    Raises:
        HTTPException: 401 if the token is invalid, expired or unknown; 403 if
            the user is disabled.
    """
    return await _authenticate_token(request, credentials.credentials, db)


async def get_current_active_user(
    current_user: User = Depends(get_current_user),
) -> User:
    """Dependency: return the current user.

    ``get_current_user`` already rejects disabled users, so this is a pass-through
    kept for routes that depend on it by name.
    """
    return current_user


async def get_user_from_token_or_query(
    request: Request,
    token: Optional[str] = None,  # taken from the query string
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(security_optional),
    db: AsyncSession = Depends(get_db),
) -> User:
    """Dependency: return the current user from the header or the ``token`` query parameter.

    Intended for resources such as images, where the client cannot always set
    headers. The ``Authorization`` header takes precedence over the query parameter.

    Raises:
        HTTPException: 401 if no token is supplied or the token is invalid; 403
            if the user is disabled.
    """
    if credentials:
        token_str = credentials.credentials
    elif token:
        token_str = token
    else:
        raise _unauthorized("未提供认证凭证")

    return await _authenticate_token(request, token_str, db)
|