main
mula.liu 2025-10-31 14:54:54 +08:00
parent 976ea854b6
commit 2f36474f4d
12 changed files with 569 additions and 51 deletions

BIN
.DS_Store vendored

Binary file not shown.

BIN
app.zip

Binary file not shown.

View File

@ -9,29 +9,25 @@ import datetime
router = APIRouter()
def _process_tags(cursor, tag_string: Optional[str]) -> List[Tag]:
def _process_tags(cursor, tag_string: Optional[str], creator_id: Optional[int] = None) -> List[Tag]:
"""
Process tags: look up existing tags and, if creator_id is provided, create any that do not exist yet.
"""
if not tag_string:
return []
tag_names = [name.strip() for name in tag_string.split(',') if name.strip()]
if not tag_names:
return []
placeholders = ','.join(['%s'] * len(tag_names))
select_query = f"SELECT id, name, color FROM tags WHERE name IN ({placeholders})"
cursor.execute(select_query, tuple(tag_names))
# If creator_id is provided, create any tags that do not exist yet
if creator_id:
insert_ignore_query = "INSERT IGNORE INTO tags (name, creator_id) VALUES (%s, %s)"
cursor.executemany(insert_ignore_query, [(name, creator_id) for name in tag_names])
# Fetch the full info for all tags
format_strings = ', '.join(['%s'] * len(tag_names))
cursor.execute(f"SELECT id, name, color FROM tags WHERE name IN ({format_strings})", tuple(tag_names))
tags_data = cursor.fetchall()
existing_tags = {tag['name']: tag for tag in tags_data}
new_tags = [name for name in tag_names if name not in existing_tags]
if new_tags:
insert_query = "INSERT INTO tags (name) VALUES (%s)"
cursor.executemany(insert_query, [(name,) for name in new_tags])
# Re-fetch all tags to get their IDs and default colors
cursor.execute(select_query, tuple(tag_names))
tags_data = cursor.fetchall()
return [Tag(**tag) for tag in tags_data]
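For context, a minimal usage sketch of the updated helper (the tag names and call site below are illustrative): with a creator_id, missing names are inserted via INSERT IGNORE before the final lookup, so the caller gets a Tag object back for every name in the string; without one, only tags that already exist are returned.
# Illustrative only - assumes an open dictionary cursor and an authenticated user
tags = _process_tags(cursor, "roadmap, budget", creator_id=current_user['user_id'])
# -> [Tag(id=..., name='roadmap', color=...), Tag(id=..., name='budget', color=...)]
_process_tags(cursor, "roadmap, budget")  # lookup only: never creates new tags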
@ -83,7 +79,8 @@ def get_knowledge_bases(
kb_list = []
for kb_data in kbs_data:
kb_data['tags'] = _process_tags(cursor, kb_data.get('tags'))
# The list view does not need to resolve tags; the string is used directly
# kb_data['tags'] stays as-is (a comma-separated string of tag names)
# Count source meetings - filter empty strings
if kb_data.get('source_meeting_ids'):
meeting_ids = [mid.strip() for mid in kb_data['source_meeting_ids'].split(',') if mid.strip()]
@ -122,7 +119,7 @@ def create_knowledge_base(
request.is_shared,
request.source_meeting_ids,
request.user_prompt,
request.tags,
request.tags,  # tags should be None or an empty string at creation time
now,
now
))
@ -136,7 +133,7 @@ def create_knowledge_base(
source_meeting_ids=request.source_meeting_ids,
cursor=cursor
)
connection.commit()
# Add the background task to process the knowledge base generation
@ -171,7 +168,8 @@ def get_knowledge_base_detail(
if not kb_data['is_shared'] and kb_data['creator_id'] != current_user['user_id']:
raise HTTPException(status_code=403, detail="Access denied")
# Process tags
# Process tags - fetch the full tag info (including color)
# The detail view never creates new tags, so creator_id is not passed
kb_data['tags'] = _process_tags(cursor, kb_data.get('tags'))
# Get source meetings details
@ -220,6 +218,10 @@ def update_knowledge_base(
if kb['creator_id'] != current_user['user_id']:
raise HTTPException(status_code=403, detail="Only the creator can update this knowledge base")
# Use _process_tags to handle the tags (new tags are created automatically)
if request.tags:
_process_tags(cursor, request.tags, current_user['user_id'])
# Update the knowledge base
now = datetime.datetime.utcnow()
update_query = """

View File

@ -23,14 +23,22 @@ transcription_service = AsyncTranscriptionService()
class GenerateSummaryRequest(BaseModel):
user_prompt: Optional[str] = ""
def _process_tags(cursor, tag_string: Optional[str]) -> List[Tag]:
def _process_tags(cursor, tag_string: Optional[str], creator_id: Optional[int] = None) -> List[Tag]:
"""
Process tags: look up existing tags and, if creator_id is provided, create any that do not exist yet.
"""
if not tag_string:
return []
tag_names = [name.strip() for name in tag_string.split(',') if name.strip()]
if not tag_names:
return []
insert_ignore_query = "INSERT IGNORE INTO tags (name) VALUES (%s)"
cursor.executemany(insert_ignore_query, [(name,) for name in tag_names])
# If creator_id is provided, create any tags that do not exist yet
if creator_id:
insert_ignore_query = "INSERT IGNORE INTO tags (name, creator_id) VALUES (%s, %s)"
cursor.executemany(insert_ignore_query, [(name, creator_id) for name in tag_names])
# Fetch the full info for all tags
format_strings = ', '.join(['%s'] * len(tag_names))
cursor.execute(f"SELECT id, name, color FROM tags WHERE name IN ({format_strings})", tuple(tag_names))
tags_data = cursor.fetchall()
@ -125,10 +133,9 @@ def get_meeting_transcript(meeting_id: int, current_user: dict = Depends(get_cur
def create_meeting(meeting_request: CreateMeetingRequest, current_user: dict = Depends(get_current_user)):
with get_db_connection() as connection:
cursor = connection.cursor(dictionary=True)
# Use _process_tags to handle tag creation
if meeting_request.tags:
tag_names = [name.strip() for name in meeting_request.tags.split(',') if name.strip()]
if tag_names:
cursor.executemany("INSERT IGNORE INTO tags (name) VALUES (%s)", [(name,) for name in tag_names])
_process_tags(cursor, meeting_request.tags, current_user['user_id'])
meeting_query = 'INSERT INTO meetings (user_id, title, meeting_time, summary, tags, created_at) VALUES (%s, %s, %s, %s, %s, %s)'
cursor.execute(meeting_query, (meeting_request.user_id, meeting_request.title, meeting_request.meeting_time, None, meeting_request.tags, datetime.now().isoformat()))
meeting_id = cursor.lastrowid
@ -147,10 +154,9 @@ def update_meeting(meeting_id: int, meeting_request: UpdateMeetingRequest, curre
return create_api_response(code="404", message="Meeting not found")
if meeting['user_id'] != current_user['user_id']:
return create_api_response(code="403", message="Permission denied")
# Use _process_tags to handle tag creation
if meeting_request.tags:
tag_names = [name.strip() for name in meeting_request.tags.split(',') if name.strip()]
if tag_names:
cursor.executemany("INSERT IGNORE INTO tags (name) VALUES (%s)", [(name,) for name in tag_names])
_process_tags(cursor, meeting_request.tags, current_user['user_id'])
update_query = 'UPDATE meetings SET title = %s, meeting_time = %s, summary = %s, tags = %s WHERE meeting_id = %s'
cursor.execute(update_query, (meeting_request.title, meeting_request.meeting_time, meeting_request.summary, meeting_request.tags, meeting_id))
cursor.execute("DELETE FROM attendees WHERE meeting_id = %s", (meeting_id,))

View File

@ -0,0 +1,131 @@
"""
Voiceprint enrollment API endpoints
"""
from fastapi import APIRouter, Depends, UploadFile, File, HTTPException
from typing import Optional
from pathlib import Path
import datetime
from app.models.models import VoiceprintStatus, VoiceprintTemplate
from app.core.auth import get_current_user
from app.core.response import create_api_response
from app.services.voiceprint_service import voiceprint_service
import app.core.config as config_module
router = APIRouter()
@router.get("/voiceprint/template", response_model=None)
def get_voiceprint_template(current_user: dict = Depends(get_current_user)):
"""
Get the reading-template configuration for voiceprint enrollment.
Permission: login required.
"""
try:
template_data = VoiceprintTemplate(**config_module.VOICEPRINT_CONFIG)
return create_api_response(code="200", message="获取朗读模板成功", data=template_data.dict())
except Exception as e:
return create_api_response(code="500", message=f"获取朗读模板失败: {str(e)}")
@router.get("/voiceprint/{user_id}", response_model=None)
def get_voiceprint_status(user_id: int, current_user: dict = Depends(get_current_user)):
"""
Get a user's voiceprint enrollment status.
Permission: users can only query their own status; admins can query anyone's.
"""
# Permission check: only the user themselves or an admin may query
if current_user['user_id'] != user_id and current_user['role_id'] != 1:
return create_api_response(code="403", message="无权限查询其他用户的声纹状态")
try:
status_data = voiceprint_service.get_user_voiceprint_status(user_id)
return create_api_response(code="200", message="获取声纹状态成功", data=status_data)
except Exception as e:
return create_api_response(code="500", message=f"获取声纹状态失败: {str(e)}")
@router.post("/voiceprint/{user_id}", response_model=None)
async def upload_voiceprint(
user_id: int,
audio_file: UploadFile = File(...),
current_user: dict = Depends(get_current_user)
):
"""
Upload a voiceprint audio file (processed synchronously).
Permission: users can only upload their own voiceprint; admins can act on anyone's.
"""
# Permission check
if current_user['user_id'] != user_id and current_user['role_id'] != 1:
return create_api_response(code="403", message="无权限上传其他用户的声纹")
# Validate the file extension
file_ext = Path(audio_file.filename).suffix.lower()
if file_ext not in config_module.ALLOWED_VOICEPRINT_EXTENSIONS:
return create_api_response(
code="400",
message=f"不支持的文件格式,仅支持: {', '.join(config_module.ALLOWED_VOICEPRINT_EXTENSIONS)}"
)
# Validate the file size
max_size = config_module.VOICEPRINT_CONFIG.get('max_file_size', 5242880)  # default 5 MB
content = await audio_file.read()
file_size = len(content)
if file_size > max_size:
return create_api_response(
code="400",
message=f"文件过大,最大允许 {max_size / 1024 / 1024:.1f}MB"
)
try:
# Ensure the per-user directory exists
user_voiceprint_dir = config_module.VOICEPRINT_DIR / str(user_id)
user_voiceprint_dir.mkdir(parents=True, exist_ok=True)
# Build the filename: <timestamp>.wav
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
filename = f"{timestamp}.wav"
file_path = user_voiceprint_dir / filename
# Save the file
with open(file_path, "wb") as f:
f.write(content)
# Hand off to the service (extract the feature vector, persist to the database)
result = voiceprint_service.save_voiceprint(user_id, str(file_path), file_size)
return create_api_response(code="200", message="声纹采集成功", data=result)
except Exception as e:
# On failure, remove the file that was already written
if 'file_path' in locals() and Path(file_path).exists():
Path(file_path).unlink()
return create_api_response(code="500", message=f"声纹采集失败: {str(e)}")
@router.delete("/voiceprint/{user_id}", response_model=None)
def delete_voiceprint(user_id: int, current_user: dict = Depends(get_current_user)):
"""
Delete a user's voiceprint data so it can be re-enrolled.
Permission: users can only delete their own voiceprint; admins can act on anyone's.
"""
# Permission check
if current_user['user_id'] != user_id and current_user['role_id'] != 1:
return create_api_response(code="403", message="无权限删除其他用户的声纹")
try:
success = voiceprint_service.delete_voiceprint(user_id)
if success:
return create_api_response(code="200", message="声纹删除成功")
else:
return create_api_response(code="404", message="未找到该用户的声纹数据")
except Exception as e:
return create_api_response(code="500", message=f"删除声纹失败: {str(e)}")

View File

@ -1,4 +1,5 @@
import os
import json
from pathlib import Path
# Base path configuration
@ -6,10 +7,12 @@ BASE_DIR = Path(__file__).parent.parent.parent
UPLOAD_DIR = BASE_DIR / "uploads"
AUDIO_DIR = UPLOAD_DIR / "audio"
MARKDOWN_DIR = UPLOAD_DIR / "markdown"
VOICEPRINT_DIR = UPLOAD_DIR / "voiceprint"
# File upload configuration
ALLOWED_EXTENSIONS = {".mp3", ".wav", ".m4a", ".mpeg", ".mp4"}
ALLOWED_IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".gif", ".webp"}
ALLOWED_VOICEPRINT_EXTENSIONS = {".wav"}
MAX_FILE_SIZE = 100 * 1024 * 1024 # 100MB
MAX_IMAGE_SIZE = 10 * 1024 * 1024 # 10MB
@ -17,6 +20,7 @@ MAX_IMAGE_SIZE = 10 * 1024 * 1024 # 10MB
UPLOAD_DIR.mkdir(exist_ok=True)
AUDIO_DIR.mkdir(exist_ok=True)
MARKDOWN_DIR.mkdir(exist_ok=True)
VOICEPRINT_DIR.mkdir(exist_ok=True)
# Database configuration
DATABASE_CONFIG = {
@ -82,3 +86,12 @@ LLM_CONFIG = {
# Password reset configuration
DEFAULT_RESET_PASSWORD = os.getenv('DEFAULT_RESET_PASSWORD', '111111')
# Load the system configuration file
# Default voiceprint configuration
VOICEPRINT_CONFIG = {
"template_text": "我正在进行声纹采集,这段语音将用于身份识别和验证。\n\n声纹技术能够准确识别每个人独特的声音特征。",
"duration_seconds": 12,
"sample_rate": 16000,
"channels": 1
}

View File

@ -128,7 +128,7 @@ class KnowledgeBase(BaseModel):
is_shared: bool
source_meeting_ids: Optional[str] = None
user_prompt: Optional[str] = None
tags: Optional[List[Tag]] = []
tags: Union[Optional[str], Optional[List[Tag]]] = None  # supports either a string or a list of Tag objects
created_at: datetime.datetime
updated_at: datetime.datetime
source_meeting_count: Optional[int] = 0
@ -204,3 +204,26 @@ class UpdateClientDownloadRequest(BaseModel):
class ClientDownloadListResponse(BaseModel):
clients: List[ClientDownload]
total: int
# Voiceprint enrollment related models
class VoiceprintInfo(BaseModel):
vp_id: int
user_id: int
file_path: str
file_size: Optional[int] = None
duration_seconds: Optional[float] = None
collected_at: datetime.datetime
updated_at: datetime.datetime
class VoiceprintStatus(BaseModel):
has_voiceprint: bool
vp_id: Optional[int] = None
file_path: Optional[str] = None
duration_seconds: Optional[float] = None
collected_at: Optional[datetime.datetime] = None
class VoiceprintTemplate(BaseModel):
template_text: str
duration_seconds: int
sample_rate: int
channels: int
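A minimal sketch of how these models pair with the service layer added in this commit (the status dict returned by voiceprint_service.get_user_voiceprint_status uses the same field names as VoiceprintStatus, so it can be validated directly):
# Sketch: validate the service's status dict against the model above
from app.models.models import VoiceprintStatus
from app.services.voiceprint_service import voiceprint_service

status_data = voiceprint_service.get_user_voiceprint_status(user_id=1)
status = VoiceprintStatus(**status_data)  # pydantic parses the ISO timestamp string into datetime
print(status.has_voiceprint, status.vp_id)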

View File

@ -0,0 +1,218 @@
"""
Voiceprint service - handles user voiceprint enrollment, storage, and verification
"""
import os
import json
import wave
from datetime import datetime
from typing import Optional, Dict
from pathlib import Path
from app.core.database import get_db_connection
import app.core.config as config_module
class VoiceprintService:
"""声纹服务类 - 同步处理声纹采集"""
def __init__(self):
self.voiceprint_dir = config_module.VOICEPRINT_DIR
self.voiceprint_config = config_module.VOICEPRINT_CONFIG
def get_user_voiceprint_status(self, user_id: int) -> Dict:
"""
Get a user's voiceprint status.
Args:
user_id: user ID
Returns:
Dict: voiceprint status information
"""
try:
with get_db_connection() as connection:
cursor = connection.cursor(dictionary=True)
query = """
SELECT vp_id, user_id, file_path, file_size, duration_seconds, collected_at, updated_at
FROM user_voiceprint
WHERE user_id = %s
"""
cursor.execute(query, (user_id,))
voiceprint = cursor.fetchone()
if voiceprint:
return {
"has_voiceprint": True,
"vp_id": voiceprint['vp_id'],
"file_path": voiceprint['file_path'],
"duration_seconds": float(voiceprint['duration_seconds']) if voiceprint['duration_seconds'] else None,
"collected_at": voiceprint['collected_at'].isoformat() if voiceprint['collected_at'] else None
}
else:
return {
"has_voiceprint": False,
"vp_id": None,
"file_path": None,
"duration_seconds": None,
"collected_at": None
}
except Exception as e:
print(f"获取声纹状态错误: {e}")
raise e
def save_voiceprint(self, user_id: int, audio_file_path: str, file_size: int) -> Dict:
"""
Save the voiceprint file and extract the feature vector.
Args:
user_id: user ID
audio_file_path: path to the audio file
file_size: file size in bytes
Returns:
Dict: result of the save operation
"""
try:
# 1. Get the audio duration
duration = self._get_audio_duration(audio_file_path)
# 2. Extract the voiceprint vector (via FunASR)
vector_data = self._extract_voiceprint_vector(audio_file_path)
# 3. Persist to the database
with get_db_connection() as connection:
cursor = connection.cursor(dictionary=True)
# Check whether the user already has a voiceprint
cursor.execute("SELECT vp_id FROM user_voiceprint WHERE user_id = %s", (user_id,))
existing = cursor.fetchone()
# Compute the path relative to BASE_DIR
relative_path = str(Path(audio_file_path).relative_to(config_module.BASE_DIR))
if existing:
# Update the existing record
update_query = """
UPDATE user_voiceprint
SET file_path = %s, file_size = %s, duration_seconds = %s,
vector_data = %s, updated_at = NOW()
WHERE user_id = %s
"""
cursor.execute(update_query, (
relative_path, file_size, duration,
json.dumps(vector_data) if vector_data else None,
user_id
))
vp_id = existing['vp_id']
else:
# Insert a new record
insert_query = """
INSERT INTO user_voiceprint
(user_id, file_path, file_size, duration_seconds, vector_data, collected_at, updated_at)
VALUES (%s, %s, %s, %s, %s, NOW(), NOW())
"""
cursor.execute(insert_query, (
user_id, relative_path, file_size, duration,
json.dumps(vector_data) if vector_data else None
))
vp_id = cursor.lastrowid
connection.commit()
return {
"vp_id": vp_id,
"user_id": user_id,
"file_path": relative_path,
"file_size": file_size,
"duration_seconds": duration,
"has_vector": vector_data is not None
}
except Exception as e:
print(f"保存声纹错误: {e}")
raise e
def delete_voiceprint(self, user_id: int) -> bool:
"""
Delete a user's voiceprint.
Args:
user_id: user ID
Returns:
bool: whether the deletion succeeded
"""
try:
with get_db_connection() as connection:
cursor = connection.cursor(dictionary=True)
# Fetch the stored file path
cursor.execute("SELECT file_path FROM user_voiceprint WHERE user_id = %s", (user_id,))
voiceprint = cursor.fetchone()
if voiceprint:
# Build the absolute file path
relative_path = voiceprint['file_path']
if relative_path.startswith('/'):
relative_path = relative_path.lstrip('/')
file_path = config_module.BASE_DIR / relative_path
# Delete the database record
cursor.execute("DELETE FROM user_voiceprint WHERE user_id = %s", (user_id,))
connection.commit()
# Delete the file
if file_path.exists():
os.remove(file_path)
return True
else:
return False
except Exception as e:
print(f"删除声纹错误: {e}")
raise e
def _get_audio_duration(self, audio_file_path: str) -> float:
"""
Get the duration of an audio file.
Args:
audio_file_path: path to the audio file
Returns:
float: duration in seconds
"""
try:
with wave.open(audio_file_path, 'rb') as wav_file:
frames = wav_file.getnframes()
rate = wav_file.getframerate()
duration = frames / float(rate)
return round(duration, 2)
except Exception as e:
print(f"获取音频时长错误: {e}")
return 10.0  # fall back to 10 seconds by default
def _extract_voiceprint_vector(self, audio_file_path: str) -> Optional[list]:
"""
Extract the voiceprint feature vector (via FunASR).
Args:
audio_file_path: path to the audio file
Returns:
Optional[list]: 192-dimensional voiceprint vector, or None on failure
"""
# TODO: integrate FunASR's speaker recognition model
# Use the speech_campplus_sv_zh-cn_16k-common model
# Return a 192-dimensional embedding vector
print(f"[TODO] 调用FunASR提取声纹向量: {audio_file_path}")
# Return None for now, pending FunASR integration
# Once integrated, this should return something like: [0.123, -0.456, 0.789, ...]
return None
# Create the module-level singleton instance
voiceprint_service = VoiceprintService()
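A hedged sketch of the FunASR/ModelScope integration hinted at in _extract_voiceprint_vector. The model ID comes from the TODO above and pipeline() is the standard ModelScope entry point, but the output_emb keyword and the shape of the result are assumptions that need to be verified against the installed modelscope/FunASR version:
# Hypothetical integration sketch - not part of this commit's code path
from typing import Optional

def extract_campplus_embedding(audio_file_path: str) -> Optional[list]:
    try:
        from modelscope.pipelines import pipeline
        # Speaker-verification pipeline built around the CAM++ model named in the TODO
        sv = pipeline(
            task='speaker-verification',
            model='damo/speech_campplus_sv_zh-cn_16k-common',
        )
        # Assumption: output_emb=True makes the pipeline return the raw embeddings
        # (expected to be 192-dimensional); the same file is passed twice only to
        # satisfy the pipeline's pairwise input format.
        result = sv([audio_file_path, audio_file_path], output_emb=True)
        embs = result.get('embs') if isinstance(result, dict) else None
        return embs[0].tolist() if embs is not None else None
    except Exception as e:
        print(f"Embedding extraction failed: {e}")
        return None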

View File

@ -1,7 +1,7 @@
{
"model_name": "qwen-plus",
"system_prompt": "你是一个专业的会议记录分析助手。请根据提供的会议转录内容,生成简洁明了的会议总结。\n\n总结包括五个部分名称严格一致生成为MD二级目录\n1. 会议概述 - 简要说明会议的主要目的和背景(生成MD引用)\n2. 主要讨论点 - 列出会议中讨论的重要话题和内容\n3. 决策事项 - 明确记录会议中做出的决定和结论\n4. 待办事项 - 列出需要后续跟进的任务和责任人\n5. 关键信息 - 其他重要的信息点\n\n输出要求\n- 保持客观中性,不添加个人观点\n- 使用简洁、准确的中文表达\n- 按重要性排序各项内容\n- 如果某个部分没有相关内容,可以说明\"无相关内容\"\n- 总字数控制在500字以内",
"DEFAULT_RESET_PASSWORD": "111111",
"MAX_FILE_SIZE": 209715200,
"DEFAULT_RESET_PASSWORD": "123456",
"MAX_FILE_SIZE": 208666624,
"MAX_IMAGE_SIZE": 10485760
}
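For reference, the MAX_FILE_SIZE change corresponds to 200 MiB (200 * 1024 * 1024 = 209715200 bytes) being lowered to 199 MiB (199 * 1024 * 1024 = 208666624 bytes).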

View File

@ -2,7 +2,7 @@ import uvicorn
from fastapi import FastAPI, Request, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from app.api.endpoints import auth, users, meetings, tags, admin, tasks, prompts, knowledge_base, client_downloads
from app.api.endpoints import auth, users, meetings, tags, admin, tasks, prompts, knowledge_base, client_downloads, voiceprint
from app.core.config import UPLOAD_DIR, API_CONFIG
from app.api.endpoints.admin import load_system_config
import os
@ -39,6 +39,7 @@ app.include_router(tasks.router, prefix="/api", tags=["Tasks"])
app.include_router(prompts.router, prefix="/api", tags=["Prompts"])
app.include_router(knowledge_base.router, prefix="/api", tags=["KnowledgeBase"])
app.include_router(client_downloads.router, prefix="/api/clients", tags=["ClientDownloads"])
app.include_router(voiceprint.router, prefix="/api", tags=["Voiceprint"])
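For reference, with the "/api" prefix this mounts the new endpoints at GET /api/voiceprint/template plus GET, POST, and DELETE /api/voiceprint/{user_id}.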
@app.get("/")
def read_root():

View File

@ -1,17 +0,0 @@
-- Add updated_at and user_prompt columns to the meetings table
-- Execution date: 2025-10-28
-- Add the updated_at column
ALTER TABLE meetings
ADD COLUMN updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
AFTER created_at;
-- Add the user_prompt column
ALTER TABLE meetings
ADD COLUMN user_prompt TEXT
AFTER summary;
-- Set updated_at to the created_at value for existing rows
UPDATE meetings
SET updated_at = created_at
WHERE updated_at IS NULL;

View File

@ -0,0 +1,141 @@
"""
Voiceprint enrollment API test script
Usage:
1. Make sure the backend service is running
2. Set USER_ID and TOKEN to real values
3. Prepare a 10-second WAV audio file
4. Run: python test_voiceprint_api.py
"""
import requests
import json
# Configuration
BASE_URL = "http://localhost:8000/api"
USER_ID = 1  # change to a real user ID
TOKEN = ""  # token obtained after logging in
# Request headers
headers = {
"Authorization": f"Bearer {TOKEN}",
"Content-Type": "application/json"
}
def test_get_template():
"""测试获取朗读模板"""
print("\n=== 测试1: 获取朗读模板 ===")
url = f"{BASE_URL}/voiceprint/template"
response = requests.get(url, headers=headers)
print(f"状态码: {response.status_code}")
print(f"响应: {json.dumps(response.json(), ensure_ascii=False, indent=2)}")
return response.json()
def test_get_status(user_id):
"""测试获取声纹状态"""
print(f"\n=== 测试2: 获取用户 {user_id} 的声纹状态 ===")
url = f"{BASE_URL}/voiceprint/{user_id}"
response = requests.get(url, headers=headers)
print(f"状态码: {response.status_code}")
print(f"响应: {json.dumps(response.json(), ensure_ascii=False, indent=2)}")
return response.json()
def test_upload_voiceprint(user_id, audio_file_path):
"""测试上传声纹"""
print(f"\n=== 测试3: 上传声纹音频 ===")
url = f"{BASE_URL}/voiceprint/{user_id}"
# Drop Content-Type so that requests sets multipart/form-data automatically
upload_headers = {
"Authorization": f"Bearer {TOKEN}"
}
with open(audio_file_path, 'rb') as f:
files = {'audio_file': (audio_file_path.split('/')[-1], f, 'audio/wav')}
response = requests.post(url, headers=upload_headers, files=files)
print(f"状态码: {response.status_code}")
print(f"响应: {json.dumps(response.json(), ensure_ascii=False, indent=2)}")
return response.json()
def test_delete_voiceprint(user_id):
"""测试删除声纹"""
print(f"\n=== 测试4: 删除用户 {user_id} 的声纹 ===")
url = f"{BASE_URL}/voiceprint/{user_id}"
response = requests.delete(url, headers=headers)
print(f"状态码: {response.status_code}")
print(f"响应: {json.dumps(response.json(), ensure_ascii=False, indent=2)}")
return response.json()
def login(username, password):
"""登录获取token"""
print("\n=== 登录获取Token ===")
url = f"{BASE_URL}/auth/login"
data = {
"username": username,
"password": password
}
response = requests.post(url, json=data)
if response.status_code == 200:
result = response.json()
if result.get('code') == '200':
token = result['data']['token']
print(f"登录成功Token: {token[:20]}...")
return token
else:
print(f"登录失败: {result.get('message')}")
return None
else:
print(f"请求失败,状态码: {response.status_code}")
return None
if __name__ == "__main__":
print("=" * 60)
print("声纹采集API测试脚本")
print("=" * 60)
# Step 1: log in (if no TOKEN is set)
if not TOKEN:
print("\n请先登录获取Token...")
username = input("用户名: ")
password = input("密码: ")
TOKEN = login(username, password)
if TOKEN:
headers["Authorization"] = f"Bearer {TOKEN}"
else:
print("登录失败,退出测试")
exit(1)
# Step 2: test fetching the reading template
test_get_template()
# Step 3: test fetching the voiceprint status
test_get_status(USER_ID)
# Step 4: test uploading a voiceprint (requires an audio file)
audio_file = input("\n请输入WAV音频文件路径 (回车跳过上传测试): ")
if audio_file.strip():
test_upload_voiceprint(USER_ID, audio_file.strip())
# Check the status again after uploading
print("\n=== 上传后再次查看状态 ===")
test_get_status(USER_ID)
# Step 5: test deleting the voiceprint
confirm = input("\n是否测试删除声纹? (yes/no): ")
if confirm.lower() == 'yes':
test_delete_voiceprint(USER_ID)
# Check the status again after deleting
print("\n=== 删除后再次查看状态 ===")
test_get_status(USER_ID)
print("\n" + "=" * 60)
print("测试完成")
print("=" * 60)