nex_docus/backend/app/api/v1/auth.py

290 lines
9.4 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""
用户认证相关 API
"""
from datetime import datetime, timezone
import logging
import os
import uuid
from pathlib import Path

import aiofiles
from fastapi import APIRouter, Depends, HTTPException, status, Request, UploadFile, File
from fastapi.responses import FileResponse
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select

from app.core.database import get_db
from app.core.security import verify_password, get_password_hash, create_access_token
from app.core.deps import get_current_user
from app.core.redis_client import TokenCache
from app.core.config import settings
from app.models.user import User
from app.models.role import Role, UserRole
from app.schemas.user import UserCreate, UserLogin, UserResponse, Token, ChangePassword, UserUpdate
from app.schemas.response import success_response, error_response
from app.services.log_service import log_service
from app.core.enums import OperationType, ResourceType
# Module-level logger shared by the handlers in this file.
logger = logging.getLogger(__name__)
# Router that the authentication / profile endpoints below are registered on.
router = APIRouter()
@router.post("/register", response_model=dict)
async def register(
    user_in: UserCreate,
    request: Request,
    db: AsyncSession = Depends(get_db)
):
    """Register a new user account.

    The user row and its default ``user`` role binding are written in a
    single transaction, so a failure while assigning the role can no longer
    leave behind a committed user without any role.

    Args:
        user_in: Registration payload (username, password, optional
            nickname / email / phone).
        request: Incoming request, forwarded to the operation log.
        db: Async database session.

    Returns:
        A ``success_response`` envelope containing the new user's id and
        username.

    Raises:
        HTTPException: 400 if the username, or the supplied email, is
            already registered.
    """
    # Reject duplicate usernames.
    result = await db.execute(select(User).where(User.username == user_in.username))
    existing_user = result.scalar_one_or_none()
    if existing_user:
        raise HTTPException(status_code=400, detail="用户名已存在")

    # Reject duplicate emails (email is optional on registration).
    if user_in.email:
        result = await db.execute(select(User).where(User.email == user_in.email))
        existing_email = result.scalar_one_or_none()
        if existing_email:
            raise HTTPException(status_code=400, detail="邮箱已被注册")

    # Create the user; the nickname falls back to the username.
    db_user = User(
        username=user_in.username,
        password_hash=get_password_hash(user_in.password),
        nickname=user_in.nickname or user_in.username,
        email=user_in.email,
        phone=user_in.phone,
        status=1,
    )
    db.add(db_user)
    # flush (not commit): emits the INSERT so db_user.id is populated while
    # keeping the transaction open for the role assignment below.
    await db.flush()

    # Attach the default role (role_code "user") if it exists.
    result = await db.execute(select(Role).where(Role.role_code == "user"))
    default_role = result.scalar_one_or_none()
    if default_role:
        db.add(UserRole(user_id=db_user.id, role_id=default_role.id))

    # One commit for user + role, then reload server-generated columns.
    await db.commit()
    await db.refresh(db_user)

    # Record the registration in the operation log.
    await log_service.log_operation(
        db=db,
        operation_type=OperationType.USER_REGISTER,
        resource_type=ResourceType.USER,
        user=db_user,
        resource_id=db_user.id,
        detail={"username": db_user.username, "email": user_in.email},
        request=request,
    )
    return success_response(
        data={"user_id": db_user.id, "username": db_user.username},
        message="注册成功"
    )
@router.post("/login", response_model=dict)
async def login(
    user_in: UserLogin,
    request: Request,
    db: AsyncSession = Depends(get_db)
):
    """Authenticate a user and issue an access token.

    Args:
        user_in: Login payload (username and password).
        request: Incoming request, forwarded to the operation log.
        db: Async database session.

    Returns:
        A ``success_response`` envelope whose data is a serialized ``Token``
        (access token plus the user's public profile).

    Raises:
        HTTPException: 401 if the username is unknown or the password is
            wrong (same message for both); 403 if the account's status is
            not 1 (disabled).
    """
    # Lifetime of the token entry stored in Redis: 24 hours.
    token_ttl_seconds = 24 * 60 * 60

    result = await db.execute(select(User).where(User.username == user_in.username))
    user = result.scalar_one_or_none()
    if not user or not verify_password(user_in.password, user.password_hash):
        raise HTTPException(status_code=401, detail="用户名或密码错误")
    if user.status != 1:
        raise HTTPException(status_code=403, detail="用户已被禁用")

    # Naive UTC timestamp, i.e. the same value datetime.utcnow() produced,
    # without the Python 3.12+ deprecation of utcnow().
    user.last_login_at = datetime.now(timezone.utc).replace(tzinfo=None)
    await db.commit()

    # The JWT "sub" claim must be a string.
    access_token = create_access_token(data={"sub": str(user.id)})
    await TokenCache.save_token(user.id, access_token, expire_seconds=token_ttl_seconds)

    # Record the login in the operation log.
    await log_service.log_operation(
        db=db,
        operation_type=OperationType.USER_LOGIN,
        resource_type=ResourceType.USER,
        user=user,
        resource_id=user.id,
        detail={"username": user.username},
        request=request,
    )

    user_data = UserResponse.from_orm(user)
    token_data = Token(access_token=access_token, user=user_data)
    return success_response(data=token_data.dict(), message="登录成功")
@router.get("/me", response_model=dict)
async def get_current_user_info(current_user: User = Depends(get_current_user)):
    """Return the authenticated user's profile.

    Args:
        current_user: User resolved by the ``get_current_user`` dependency.

    Returns:
        A ``success_response`` envelope wrapping the serialized
        ``UserResponse``.
    """
    return success_response(data=UserResponse.from_orm(current_user).dict())
@router.put("/profile", response_model=dict)
async def update_profile(
    profile_in: UserUpdate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db)
):
    """Apply a partial profile update to the authenticated user.

    Only fields explicitly present in the request body are written
    (``exclude_unset=True``).

    Args:
        profile_in: Partial profile payload.
        current_user: User resolved by the ``get_current_user`` dependency.
        db: Async database session.

    Returns:
        A ``success_response`` envelope wrapping the refreshed
        ``UserResponse``.

    Raises:
        HTTPException: 400 if the requested email belongs to another user.
    """
    requested_email = profile_in.email
    if requested_email:
        conflict_query = select(User).where(
            User.email == requested_email, User.id != current_user.id
        )
        conflicting_user = (await db.execute(conflict_query)).scalar_one_or_none()
        if conflicting_user:
            raise HTTPException(status_code=400, detail="邮箱已被其他用户使用")

    # NOTE(review): every field set on UserUpdate is copied onto the user row;
    # confirm the schema only exposes user-editable fields.
    changes = profile_in.dict(exclude_unset=True)
    for attr_name, attr_value in changes.items():
        setattr(current_user, attr_name, attr_value)

    await db.commit()
    await db.refresh(current_user)

    refreshed_profile = UserResponse.from_orm(current_user)
    return success_response(data=refreshed_profile.dict(), message="资料更新成功")
@router.post("/change-password", response_model=dict)
async def change_password(
    password_in: ChangePassword,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db)
):
    """Change the authenticated user's password and revoke their tokens.

    Args:
        password_in: Payload with the current and the new password.
        current_user: User resolved by the ``get_current_user`` dependency.
        db: Async database session.

    Returns:
        A ``success_response`` envelope asking the user to log in again.

    Raises:
        HTTPException: 400 if the old password does not match.
    """
    old_password_ok = verify_password(password_in.old_password, current_user.password_hash)
    if not old_password_ok:
        raise HTTPException(status_code=400, detail="旧密码错误")

    new_hash = get_password_hash(password_in.new_password)
    current_user.password_hash = new_hash
    await db.commit()

    # Drop every cached token for this user so all sessions must re-login.
    await TokenCache.delete_user_all_tokens(current_user.id)
    return success_response(message="密码修改成功,请重新登录")
@router.post("/logout", response_model=dict)
async def logout(
    request: Request,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db)
):
    """Log the current user out by removing the request's token from the cache.

    Args:
        request: Incoming request; its ``state.token`` (if present) is the
            token to revoke.
        current_user: User resolved by the ``get_current_user`` dependency.
        db: Async database session used for the operation log.

    Returns:
        A ``success_response`` envelope confirming the logout.
    """
    # Presumably populated by get_current_user; absent -> nothing to revoke.
    request_token = getattr(request.state, 'token', None)
    if request_token:
        await TokenCache.delete_token(request_token)
        logger.info(f"User {current_user.username} logged out")

    # Record the logout in the operation log.
    log_kwargs = dict(
        db=db,
        operation_type=OperationType.USER_LOGOUT,
        resource_type=ResourceType.USER,
        user=current_user,
        resource_id=current_user.id,
        detail={"username": current_user.username},
        request=request,
    )
    await log_service.log_operation(**log_kwargs)
    return success_response(message="退出成功")
# Accepted avatar MIME types -> (extension used on disk, required leading bytes).
# The extension is derived from the validated type, never from the
# client-supplied filename, and the content must start with the format's
# magic signature because the declared content type is client-controlled.
_AVATAR_TYPES = {
    "image/jpeg": (".jpg", b"\xff\xd8\xff"),
    "image/jpg": (".jpg", b"\xff\xd8\xff"),
    "image/png": (".png", b"\x89PNG\r\n\x1a\n"),
}


def _remove_old_avatar(avatar_rel_path: str, avatar_dir: Path) -> None:
    """Best-effort delete of a previous avatar, confined to *avatar_dir*."""
    candidate = (Path(settings.USERS_PATH) / avatar_rel_path).resolve()
    # Only files directly inside this user's avatar directory may be deleted;
    # a stored value containing ".." or an absolute path is ignored.
    if candidate.parent != avatar_dir.resolve():
        logger.warning(f"Refusing to delete avatar outside {avatar_dir}: {avatar_rel_path}")
        return
    try:
        candidate.unlink()
    except FileNotFoundError:
        pass
    except OSError as e:
        logger.warning(f"Failed to delete old avatar: {e}")


@router.post("/upload-avatar", response_model=dict)
async def upload_avatar(
    file: UploadFile = File(...),
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db)
):
    """Upload a new avatar (JPEG or PNG) for the authenticated user.

    The file is stored as ``USERS_PATH/<user_id>/avatar/<uuid><ext>``, the
    user's ``avatar`` column is set to that path relative to ``USERS_PATH``,
    and only after the commit succeeds is the previous avatar file removed.

    Args:
        file: Uploaded image.
        current_user: User resolved by the ``get_current_user`` dependency.
        db: Async database session.

    Returns:
        A ``success_response`` envelope wrapping the refreshed
        ``UserResponse``.

    Raises:
        HTTPException: 400 if the type is not JPEG/PNG, the content does not
            match the declared type, or the size exceeds
            ``settings.AVATAR_MAX_SIZE`` bytes.
    """
    avatar_type = _AVATAR_TYPES.get(file.content_type)
    if avatar_type is None:
        raise HTTPException(status_code=400, detail="仅支持 JPG、PNG 格式的图片")
    file_ext, signature = avatar_type

    # Read at most one byte past the limit: enough to detect an oversized
    # upload without buffering all of it in memory.
    max_size = settings.AVATAR_MAX_SIZE
    file_content = await file.read(max_size + 1)
    if len(file_content) > max_size:
        max_mb = max_size / (1024 * 1024)
        raise HTTPException(status_code=400, detail=f"文件大小不能超过 {max_mb:g}MB")
    if not file_content.startswith(signature):
        raise HTTPException(status_code=400, detail="仅支持 JPG、PNG 格式的图片")

    user_avatar_dir = Path(settings.USERS_PATH) / str(current_user.id) / "avatar"
    user_avatar_dir.mkdir(parents=True, exist_ok=True)

    unique_filename = f"{uuid.uuid4()}{file_ext}"
    file_path = user_avatar_dir / unique_filename
    async with aiofiles.open(file_path, 'wb') as f:
        await f.write(file_content)

    # Store the path relative to USERS_PATH so the frontend can build the URL.
    old_avatar = current_user.avatar
    current_user.avatar = f"{current_user.id}/avatar/{unique_filename}"
    try:
        await db.commit()
    except Exception:
        # Don't leave an orphaned file behind when the DB update fails.
        file_path.unlink(missing_ok=True)
        raise

    # The new avatar is persisted; now it is safe to drop the old file.
    if old_avatar:
        _remove_old_avatar(old_avatar, user_avatar_dir)

    await db.refresh(current_user)
    user_data = UserResponse.from_orm(current_user)
    return success_response(data=user_data.dict(), message="头像上传成功")
# File extension -> Content-Type for served avatars. Unknown extensions keep
# the previous default of image/jpeg.
_AVATAR_MEDIA_TYPES = {
    ".jpg": "image/jpeg",
    ".jpeg": "image/jpeg",
    ".png": "image/png",
}


@router.get("/avatar/{user_id}/{filename}")
async def get_avatar(
    user_id: int,
    filename: str
):
    """Serve a stored avatar image.

    Args:
        user_id: Owner of the avatar.
        filename: Bare file name inside ``USERS_PATH/<user_id>/avatar``.

    Returns:
        A ``FileResponse`` whose media type is chosen from the file
        extension, with long-lived caching (avatar names are unique per
        upload).

    Raises:
        HTTPException: 404 if the name is not a plain file name or no such
            file exists.
    """
    # Accept only a plain file name: no path separators, no "." / "..".
    if not filename or filename in {".", ".."} or Path(filename).name != filename:
        raise HTTPException(status_code=404, detail="头像不存在")

    avatar_path = Path(settings.USERS_PATH) / str(user_id) / "avatar" / filename
    # is_file() is False for missing paths, so no separate exists() check.
    if not avatar_path.is_file():
        raise HTTPException(status_code=404, detail="头像不存在")

    media_type = _AVATAR_MEDIA_TYPES.get(avatar_path.suffix.lower(), "image/jpeg")
    return FileResponse(
        path=str(avatar_path),
        media_type=media_type,
        headers={
            "Cache-Control": "public, max-age=31536000, immutable",
            # Stop browsers from sniffing the body into a different type.
            "X-Content-Type-Options": "nosniff",
        }
    )