"""
Authentication dependencies: resolve the currently logged-in user.

A bearer token is accepted only when all of the following hold:

1. ``TokenCache.get_user_id(token)`` returns a user id (the token is still
   present in Redis, i.e. not expired or revoked there);
2. ``decode_access_token(token)`` returns a payload (the JWT is valid);
3. the payload's ``sub`` claim parses as an integer equal to the user id
   cached in Redis for that token;
4. the user row exists and ``User.status == 1`` (otherwise it is disabled).
"""
import logging
from typing import Optional

from fastapi import Depends, HTTPException, Request, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.database import get_db
from app.core.redis_client import TokenCache
from app.core.security import decode_access_token
from app.models.user import User

logger = logging.getLogger(__name__)

# HTTP Bearer authentication schemes.
# `security` rejects requests lacking an Authorization header;
# `security_optional` yields None instead so the dependency can fall back.
security = HTTPBearer()
security_optional = HTTPBearer(auto_error=False)

# `User.status` value of an enabled account; any other value is "disabled".
_ACTIVE_STATUS = 1


class _AuthError(Exception):
    """Internal authentication failure.

    Kept distinct from ``HTTPException`` so the optional dependency can reject
    a token silently, while the mandatory dependencies log the failure and
    convert it into an HTTP error response.
    """

    def __init__(self, status_code: int, detail: str, log_message: str) -> None:
        super().__init__(log_message)
        self.status_code = status_code  # HTTP status to respond with
        self.detail = detail  # user-facing message returned to the client
        self.log_message = log_message  # server-side diagnostic message

    def to_http_exception(self) -> HTTPException:
        """Build the HTTP error; every 401 carries the Bearer challenge header."""
        headers = (
            {"WWW-Authenticate": "Bearer"}
            if self.status_code == status.HTTP_401_UNAUTHORIZED
            else None
        )
        return HTTPException(
            status_code=self.status_code,
            detail=self.detail,
            headers=headers,
        )


def _unauthorized(detail: str, log_message: str) -> _AuthError:
    """Shorthand for a 401 authentication failure."""
    return _AuthError(status.HTTP_401_UNAUTHORIZED, detail, log_message)


def _coerce_user_id(value: object) -> Optional[int]:
    """Parse a user id (int, str or bytes) into an int; None if not an integer."""
    try:
        return int(value)
    except (TypeError, ValueError):
        return None


async def _authenticate_token(token: str, db: AsyncSession) -> User:
    """Validate *token* and return the enabled user it belongs to.

    Args:
        token: Raw bearer token string.
        db: Async database session used to load the user.

    Returns:
        The authenticated, enabled ``User``.

    Raises:
        _AuthError: 401 for missing/expired/invalid credentials or an unknown
            user; 403 when the user account is disabled.
    """
    # Check Redis first: a token absent there is rejected even if the JWT
    # itself is still valid (presumably this is how logout revokes tokens).
    user_id_from_redis = await TokenCache.get_user_id(token)
    if user_id_from_redis is None:
        raise _unauthorized("登录已过期,请重新登录", "Token not found in Redis or expired")

    # Decode the JWT to verify its integrity.
    payload = decode_access_token(token)
    if payload is None:
        raise _unauthorized("无效的认证凭证", "Token decode failed: payload is None")

    # Only the subject is logged (at DEBUG); the full payload may hold claims
    # that should not end up in application logs.
    user_id_str = payload.get("sub")
    logger.debug("Token subject: %s", user_id_str)
    if user_id_str is None:
        raise _unauthorized("无效的认证凭证", "user_id is None in payload")

    # `sub` is carried as a string in the JWT; convert it to an int.
    user_id = _coerce_user_id(user_id_str)
    if user_id is None:
        raise _unauthorized("无效的认证凭证", f"Invalid user_id format: {user_id_str}")

    # The Redis entry and the JWT must name the same user. The cached value is
    # coerced too, so a str/bytes id from the cache still compares correctly.
    if user_id != _coerce_user_id(user_id_from_redis):
        raise _unauthorized(
            "无效的认证凭证",
            f"User ID mismatch: JWT={user_id}, Redis={user_id_from_redis}",
        )

    result = await db.execute(select(User).where(User.id == user_id))
    user = result.scalar_one_or_none()
    if user is None:
        raise _unauthorized("用户不存在", f"User not found for user_id: {user_id}")
    if user.status != _ACTIVE_STATUS:
        raise _AuthError(
            status.HTTP_403_FORBIDDEN, "用户已被禁用", f"User {user_id} is disabled"
        )
    return user


async def _authenticate_or_raise(request: Request, token: str, db: AsyncSession) -> User:
    """Shared body of the mandatory dependencies: authenticate or raise HTTPException."""
    # Never log token content: even a prefix of a credential does not belong in logs.
    logger.debug("Authenticating bearer token")
    # Store the raw token on the request state; the logout endpoint uses it.
    request.state.token = token
    try:
        user = await _authenticate_token(token, db)
    except _AuthError as exc:
        logger.error(exc.log_message)
        raise exc.to_http_exception() from None
    logger.debug("User authenticated successfully: %s", user.username)
    return user


async def get_current_user(
    request: Request,
    credentials: HTTPAuthorizationCredentials = Depends(security),
    db: AsyncSession = Depends(get_db),
) -> User:
    """Return the logged-in user (dependency; authentication is mandatory).

    Raises:
        HTTPException: 401 for missing/expired/invalid credentials or an
            unknown user; 403 when the account is disabled.
    """
    return await _authenticate_or_raise(request, credentials.credentials, db)


async def get_current_user_optional(
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(security_optional),
    db: AsyncSession = Depends(get_db),
) -> Optional[User]:
    """Return the logged-in user, or None (dependency; login is not required).

    Returns None when no token is supplied or when the token is invalid,
    expired, or belongs to a missing or disabled user.
    """
    if not credentials:
        return None
    try:
        return await _authenticate_token(credentials.credentials, db)
    except _AuthError:
        # Rejected credentials simply mean "anonymous" here; no logging needed.
        return None
    except Exception as e:
        # Deliberate best-effort: any other error (presumably Redis/DB trouble)
        # also degrades to anonymous access, but is logged.
        logger.warning(f"Optional auth failed: {str(e)}")
        return None


async def get_current_active_user(
    current_user: User = Depends(get_current_user)
) -> User:
    """Return the current active user.

    ``get_current_user`` already rejects disabled accounts, so this is a
    pass-through alias.
    """
    return current_user


async def get_user_from_token_or_query(
    request: Request,
    token: Optional[str] = None,  # taken from the query string
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(security_optional),
    db: AsyncSession = Depends(get_db),
) -> User:
    """Return the logged-in user, reading the token from the header or query.

    Intended for resource URLs such as images, where the client cannot set an
    Authorization header. The header takes precedence; the ``token`` query
    parameter is the fallback.

    NOTE(review): tokens passed in a URL commonly end up in access logs,
    browser history and Referer headers; prefer the header where possible.

    Raises:
        HTTPException: 401 when no credentials are supplied or they are
            invalid/expired; 403 when the account is disabled.
    """
    if credentials:
        token_str = credentials.credentials
    elif token:
        token_str = token
    else:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="未提供认证凭证",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return await _authenticate_or_raise(request, token_str, db)