imetting_backend/app/api/endpoints/voiceprint.py

132 lines
4.8 KiB
Python

"""
声纹采集API接口
"""
from fastapi import APIRouter, Depends, UploadFile, File, HTTPException
from typing import Optional
from pathlib import Path
import datetime
from app.models.models import VoiceprintStatus, VoiceprintTemplate
from app.core.auth import get_current_user
from app.core.response import create_api_response
from app.services.voiceprint_service import voiceprint_service
import app.core.config as config_module
router = APIRouter()
@router.get("/voiceprint/template", response_model=None)
def get_voiceprint_template(current_user: dict = Depends(get_current_user)):
"""
获取声纹采集朗读模板配置
权限:需要登录
"""
try:
template_data = VoiceprintTemplate(**config_module.VOICEPRINT_CONFIG)
return create_api_response(code="200", message="获取朗读模板成功", data=template_data.dict())
except Exception as e:
return create_api_response(code="500", message=f"获取朗读模板失败: {str(e)}")
@router.get("/voiceprint/{user_id}", response_model=None)
def get_voiceprint_status(user_id: int, current_user: dict = Depends(get_current_user)):
"""
获取用户声纹采集状态
权限:用户只能查询自己的声纹状态,管理员可查询所有
"""
# 权限检查:只能查询自己的声纹,或者是管理员
if current_user['user_id'] != user_id and current_user['role_id'] != 1:
return create_api_response(code="403", message="无权限查询其他用户的声纹状态")
try:
status_data = voiceprint_service.get_user_voiceprint_status(user_id)
return create_api_response(code="200", message="获取声纹状态成功", data=status_data)
except Exception as e:
return create_api_response(code="500", message=f"获取声纹状态失败: {str(e)}")
@router.post("/voiceprint/{user_id}", response_model=None)
async def upload_voiceprint(
user_id: int,
audio_file: UploadFile = File(...),
current_user: dict = Depends(get_current_user)
):
"""
上传声纹音频文件(同步处理)
权限:用户只能上传自己的声纹,管理员可操作所有
"""
# 权限检查
if current_user['user_id'] != user_id and current_user['role_id'] != 1:
return create_api_response(code="403", message="无权限上传其他用户的声纹")
# 检查文件格式
file_ext = Path(audio_file.filename).suffix.lower()
if file_ext not in config_module.ALLOWED_VOICEPRINT_EXTENSIONS:
return create_api_response(
code="400",
message=f"不支持的文件格式,仅支持: {', '.join(config_module.ALLOWED_VOICEPRINT_EXTENSIONS)}"
)
# 检查文件大小
max_size = config_module.VOICEPRINT_CONFIG.get('max_file_size', 5242880) # 默认5MB
content = await audio_file.read()
file_size = len(content)
if file_size > max_size:
return create_api_response(
code="400",
message=f"文件过大,最大允许 {max_size / 1024 / 1024:.1f}MB"
)
try:
# 确保用户目录存在
user_voiceprint_dir = config_module.VOICEPRINT_DIR / str(user_id)
user_voiceprint_dir.mkdir(parents=True, exist_ok=True)
# 生成文件名:时间戳.wav
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
filename = f"{timestamp}.wav"
file_path = user_voiceprint_dir / filename
# 保存文件
with open(file_path, "wb") as f:
f.write(content)
# 调用服务处理声纹(提取特征向量,保存到数据库)
result = voiceprint_service.save_voiceprint(user_id, str(file_path), file_size)
return create_api_response(code="200", message="声纹采集成功", data=result)
except Exception as e:
# 如果出错,删除已上传的文件
if 'file_path' in locals() and Path(file_path).exists():
Path(file_path).unlink()
return create_api_response(code="500", message=f"声纹采集失败: {str(e)}")
@router.delete("/voiceprint/{user_id}", response_model=None)
def delete_voiceprint(user_id: int, current_user: dict = Depends(get_current_user)):
"""
删除用户声纹数据,允许重新采集
权限:用户只能删除自己的声纹,管理员可操作所有
"""
# 权限检查
if current_user['user_id'] != user_id and current_user['role_id'] != 1:
return create_api_response(code="403", message="无权限删除其他用户的声纹")
try:
success = voiceprint_service.delete_voiceprint(user_id)
if success:
return create_api_response(code="200", message="声纹删除成功")
else:
return create_api_response(code="404", message="未找到该用户的声纹数据")
except Exception as e:
return create_api_response(code="500", message=f"删除声纹失败: {str(e)}")