nex_docus/backend/app/core/deps.py

266 lines
8.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""
认证依赖:获取当前登录用户
"""
from fastapi import Depends, HTTPException, status, Request
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from typing import Optional
from app.core.database import get_db
from app.core.security import decode_access_token
from app.core.redis_client import TokenCache
from app.models.user import User
import logging
logger = logging.getLogger(__name__)

# Bearer-token extraction schemes shared by the dependencies in this module.
# `security` is used where login is mandatory: with auto_error=True (FastAPI's
# default, spelled out here) HTTPBearer itself rejects a missing/malformed
# Authorization header.
# `security_optional` is used where login is optional: it yields None instead
# of raising, and the caller decides what an anonymous request means.
security = HTTPBearer(auto_error=True)
security_optional = HTTPBearer(auto_error=False)
def _auth_error(detail: str, status_code: int = status.HTTP_401_UNAUTHORIZED) -> HTTPException:
    """Build the HTTPException raised for an authentication/authorization failure.

    Every 401 carries ``WWW-Authenticate: Bearer`` (RFC 7235 requires the
    header on 401 responses); other status codes (e.g. 403) get no header.
    """
    headers = {"WWW-Authenticate": "Bearer"} if status_code == status.HTTP_401_UNAUTHORIZED else None
    return HTTPException(status_code=status_code, detail=detail, headers=headers)


async def get_current_user(
    request: Request,
    credentials: HTTPAuthorizationCredentials = Depends(security),
    db: AsyncSession = Depends(get_db)
) -> User:
    """
    Resolve the currently logged-in user (FastAPI dependency).

    The bearer token is accepted only if all of the following hold:
    1. ``TokenCache.get_user_id`` returns a user id for it;
    2. ``decode_access_token`` returns a payload (not None);
    3. the payload has a ``sub`` claim that parses as an int;
    4. that int equals the user id stored in Redis;
    5. a ``User`` row with that id exists and its ``status`` is 1.

    Args:
        request: Incoming request. The raw token is stored on
            ``request.state.token`` so the logout handler can use it.
        credentials: Bearer credentials extracted by ``security``.
        db: Async database session.

    Returns:
        The authenticated ``User``.

    Raises:
        HTTPException: 401 (with ``WWW-Authenticate: Bearer``) when the token
            is unknown to Redis, undecodable, has a missing/invalid ``sub``,
            disagrees with Redis, or the user does not exist; 403 when the
            user's ``status`` is not 1.
    """
    token = credentials.credentials
    # Keep the raw token on the request so logout can use it later.
    request.state.token = token

    # Redis is checked first: a token with no Redis entry is rejected as an
    # expired login before the JWT itself is even decoded.
    user_id_from_redis = await TokenCache.get_user_id(token)
    if user_id_from_redis is None:
        logger.warning("Token not found in Redis or expired")
        raise _auth_error("登录已过期,请重新登录")

    # Decode the JWT to verify its integrity. The payload itself is not
    # logged: it would leak token claims into the application logs.
    payload = decode_access_token(token)
    if payload is None:
        logger.warning("Token decode failed: payload is None")
        raise _auth_error("无效的认证凭证")

    user_id_str = payload.get("sub")
    if user_id_str is None:
        logger.warning("user_id is None in payload")
        raise _auth_error("无效的认证凭证")

    # The "sub" claim is read as a string; user ids are compared as ints.
    try:
        user_id = int(user_id_str)
    except (ValueError, TypeError):
        logger.warning("Invalid user_id format: %r", user_id_str)
        raise _auth_error("无效的认证凭证") from None

    # Reject a token whose Redis entry points at a different user.
    if user_id != user_id_from_redis:
        logger.warning("User ID mismatch: JWT=%s, Redis=%s", user_id, user_id_from_redis)
        raise _auth_error("无效的认证凭证")

    result = await db.execute(select(User).where(User.id == user_id))
    user = result.scalar_one_or_none()
    if user is None:
        logger.warning("User not found for user_id: %s", user_id)
        raise _auth_error("用户不存在")

    # Only status == 1 is accepted; any other value is treated as disabled.
    if user.status != 1:
        logger.warning("User %s is disabled", user_id)
        raise _auth_error("用户已被禁用", status.HTTP_403_FORBIDDEN)

    # Runs on every authenticated request, so keep it out of INFO logs.
    logger.debug("User authenticated successfully: %s", user.username)
    return user
async def get_current_user_optional(
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(security_optional),
    db: AsyncSession = Depends(get_db)
) -> Optional[User]:
    """
    Resolve the logged-in user without requiring login (FastAPI dependency).

    Runs the same checks as ``get_current_user``: a Redis lookup, JWT
    decoding, an integer ``sub`` claim that matches the Redis user id, and
    an existing user whose ``status`` is 1. Instead of raising, it returns
    None when any check fails or when no bearer credentials were sent.

    Any exception raised during the lookup (Redis, JWT decoding, database
    query) is logged as a warning and also results in None.
    """
    if not credentials:
        return None

    raw_token = credentials.credentials
    try:
        cached_user_id = await TokenCache.get_user_id(raw_token)
        if cached_user_id is None:
            return None

        # A missing payload and a missing "sub" claim both mean "anonymous".
        claims = decode_access_token(raw_token)
        subject = None if claims is None else claims.get("sub")
        if subject is None:
            return None

        try:
            claimed_user_id = int(subject)
        except (ValueError, TypeError):
            return None

        if claimed_user_id != cached_user_id:
            return None

        rows = await db.execute(select(User).where(User.id == claimed_user_id))
        candidate = rows.scalar_one_or_none()
        return None if candidate is None or candidate.status != 1 else candidate
    except Exception as exc:
        logger.warning(f"Optional auth failed: {str(exc)}")
        return None
async def get_current_active_user(current_user: User = Depends(get_current_user)) -> User:
    """
    Return the authenticated user, for endpoints that require an active account.

    No extra check is done here. ``get_current_user`` already raises 403
    for users whose ``status`` is not 1, so the user passed in is active.
    """
    return current_user
async def get_user_from_token_or_query(
    request: Request,
    token: Optional[str] = None,  # token taken from the query string (?token=...)
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(security_optional),
    db: AsyncSession = Depends(get_db)
) -> User:
    """
    Resolve the current user from the Authorization header or a ``token`` query parameter.

    Used for images and similar resources. The header takes precedence, and
    the query-string token is only a fallback.

    Once the token is chosen, validation is delegated to ``get_current_user``.
    Both dependencies therefore enforce exactly the same rules (Redis lookup,
    JWT decode, ``sub``/Redis match, user existence and status), and both set
    ``request.state.token``.

    Raises:
        HTTPException: 401 when neither source provides a token; otherwise
            whatever ``get_current_user`` raises (401 / 403).
    """
    if credentials:
        bearer = credentials
    elif token:
        # NOTE(security): query-string tokens can end up in access logs,
        # browser history and Referer headers. Prefer the header whenever
        # the client is able to set it.
        bearer = HTTPAuthorizationCredentials(scheme="Bearer", credentials=token)
    else:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="未提供认证凭证",
            headers={"WWW-Authenticate": "Bearer"},
        )

    return await get_current_user(request=request, credentials=bearer, db=db)