diff --git a/.DS_Store b/.DS_Store index ef921b4..2b4cb54 100644 Binary files a/.DS_Store and b/.DS_Store differ diff --git a/app.zip b/app.zip deleted file mode 100644 index 47f81f0..0000000 Binary files a/app.zip and /dev/null differ diff --git a/app/api/endpoints/knowledge_base.py b/app/api/endpoints/knowledge_base.py index 9bc9154..93d129d 100644 --- a/app/api/endpoints/knowledge_base.py +++ b/app/api/endpoints/knowledge_base.py @@ -9,29 +9,25 @@ import datetime router = APIRouter() -def _process_tags(cursor, tag_string: Optional[str]) -> List[Tag]: +def _process_tags(cursor, tag_string: Optional[str], creator_id: Optional[int] = None) -> List[Tag]: + """ + 处理标签:查询已存在的标签,如果提供了 creator_id 则创建不存在的标签 + """ if not tag_string: return [] tag_names = [name.strip() for name in tag_string.split(',') if name.strip()] if not tag_names: return [] - - placeholders = ','.join(['%s'] * len(tag_names)) - select_query = f"SELECT id, name, color FROM tags WHERE name IN ({placeholders})" - cursor.execute(select_query, tuple(tag_names)) + + # 如果提供了 creator_id,则创建不存在的标签 + if creator_id: + insert_ignore_query = "INSERT IGNORE INTO tags (name, creator_id) VALUES (%s, %s)" + cursor.executemany(insert_ignore_query, [(name, creator_id) for name in tag_names]) + + # 查询所有标签信息 + format_strings = ', '.join(['%s'] * len(tag_names)) + cursor.execute(f"SELECT id, name, color FROM tags WHERE name IN ({format_strings})", tuple(tag_names)) tags_data = cursor.fetchall() - - existing_tags = {tag['name']: tag for tag in tags_data} - new_tags = [name for name in tag_names if name not in existing_tags] - - if new_tags: - insert_query = "INSERT INTO tags (name) VALUES (%s)" - cursor.executemany(insert_query, [(name,) for name in new_tags]) - - # Re-fetch all tags to get their IDs and default colors - cursor.execute(select_query, tuple(tag_names)) - tags_data = cursor.fetchall() - return [Tag(**tag) for tag in tags_data] @@ -83,7 +79,8 @@ def get_knowledge_bases( kb_list = [] for kb_data in kbs_data: - kb_data['tags'] = _process_tags(cursor, kb_data.get('tags')) + # 列表页不需要处理 tags,直接使用字符串 + # kb_data['tags'] 保持原样(逗号分隔的标签名称字符串) # Count source meetings - filter empty strings if kb_data.get('source_meeting_ids'): meeting_ids = [mid.strip() for mid in kb_data['source_meeting_ids'].split(',') if mid.strip()] @@ -122,7 +119,7 @@ def create_knowledge_base( request.is_shared, request.source_meeting_ids, request.user_prompt, - request.tags, + request.tags, # 创建时 tags 应该为 None 或空字符串 now, now )) @@ -136,7 +133,7 @@ def create_knowledge_base( source_meeting_ids=request.source_meeting_ids, cursor=cursor ) - + connection.commit() # Add the background task to process the knowledge base generation @@ -171,7 +168,8 @@ def get_knowledge_base_detail( if not kb_data['is_shared'] and kb_data['creator_id'] != current_user['user_id']: raise HTTPException(status_code=403, detail="Access denied") - # Process tags + # Process tags - 获取标签的完整信息(包括颜色) + # 详情页不需要创建新标签,所以不传 creator_id kb_data['tags'] = _process_tags(cursor, kb_data.get('tags')) # Get source meetings details @@ -220,6 +218,10 @@ def update_knowledge_base( if kb['creator_id'] != current_user['user_id']: raise HTTPException(status_code=403, detail="Only the creator can update this knowledge base") + # 使用 _process_tags 处理标签(会自动创建新标签) + if request.tags: + _process_tags(cursor, request.tags, current_user['user_id']) + # Update the knowledge base now = datetime.datetime.utcnow() update_query = """ diff --git a/app/api/endpoints/meetings.py b/app/api/endpoints/meetings.py index 9ef76b2..ff632bf 100644 --- a/app/api/endpoints/meetings.py +++ b/app/api/endpoints/meetings.py @@ -23,14 +23,22 @@ transcription_service = AsyncTranscriptionService() class GenerateSummaryRequest(BaseModel): user_prompt: Optional[str] = "" -def _process_tags(cursor, tag_string: Optional[str]) -> List[Tag]: +def _process_tags(cursor, tag_string: Optional[str], creator_id: Optional[int] = None) -> List[Tag]: + """ + 处理标签:查询已存在的标签,如果提供了 creator_id 则创建不存在的标签 + """ if not tag_string: return [] tag_names = [name.strip() for name in tag_string.split(',') if name.strip()] if not tag_names: return [] - insert_ignore_query = "INSERT IGNORE INTO tags (name) VALUES (%s)" - cursor.executemany(insert_ignore_query, [(name,) for name in tag_names]) + + # 如果提供了 creator_id,则创建不存在的标签 + if creator_id: + insert_ignore_query = "INSERT IGNORE INTO tags (name, creator_id) VALUES (%s, %s)" + cursor.executemany(insert_ignore_query, [(name, creator_id) for name in tag_names]) + + # 查询所有标签信息 format_strings = ', '.join(['%s'] * len(tag_names)) cursor.execute(f"SELECT id, name, color FROM tags WHERE name IN ({format_strings})", tuple(tag_names)) tags_data = cursor.fetchall() @@ -125,10 +133,9 @@ def get_meeting_transcript(meeting_id: int, current_user: dict = Depends(get_cur def create_meeting(meeting_request: CreateMeetingRequest, current_user: dict = Depends(get_current_user)): with get_db_connection() as connection: cursor = connection.cursor(dictionary=True) + # 使用 _process_tags 来处理标签创建 if meeting_request.tags: - tag_names = [name.strip() for name in meeting_request.tags.split(',') if name.strip()] - if tag_names: - cursor.executemany("INSERT IGNORE INTO tags (name) VALUES (%s)", [(name,) for name in tag_names]) + _process_tags(cursor, meeting_request.tags, current_user['user_id']) meeting_query = 'INSERT INTO meetings (user_id, title, meeting_time, summary, tags, created_at) VALUES (%s, %s, %s, %s, %s, %s)' cursor.execute(meeting_query, (meeting_request.user_id, meeting_request.title, meeting_request.meeting_time, None, meeting_request.tags, datetime.now().isoformat())) meeting_id = cursor.lastrowid @@ -147,10 +154,9 @@ def update_meeting(meeting_id: int, meeting_request: UpdateMeetingRequest, curre return create_api_response(code="404", message="Meeting not found") if meeting['user_id'] != current_user['user_id']: return create_api_response(code="403", message="Permission denied") + # 使用 _process_tags 来处理标签创建 if meeting_request.tags: - tag_names = [name.strip() for name in meeting_request.tags.split(',') if name.strip()] - if tag_names: - cursor.executemany("INSERT IGNORE INTO tags (name) VALUES (%s)", [(name,) for name in tag_names]) + _process_tags(cursor, meeting_request.tags, current_user['user_id']) update_query = 'UPDATE meetings SET title = %s, meeting_time = %s, summary = %s, tags = %s WHERE meeting_id = %s' cursor.execute(update_query, (meeting_request.title, meeting_request.meeting_time, meeting_request.summary, meeting_request.tags, meeting_id)) cursor.execute("DELETE FROM attendees WHERE meeting_id = %s", (meeting_id,)) diff --git a/app/api/endpoints/voiceprint.py b/app/api/endpoints/voiceprint.py new file mode 100644 index 0000000..b7e333b --- /dev/null +++ b/app/api/endpoints/voiceprint.py @@ -0,0 +1,131 @@ +""" +声纹采集API接口 +""" +from fastapi import APIRouter, Depends, UploadFile, File, HTTPException +from typing import Optional +from pathlib import Path +import datetime + +from app.models.models import VoiceprintStatus, VoiceprintTemplate +from app.core.auth import get_current_user +from app.core.response import create_api_response +from app.services.voiceprint_service import voiceprint_service +import app.core.config as config_module + +router = APIRouter() + + +@router.get("/voiceprint/template", response_model=None) +def get_voiceprint_template(current_user: dict = Depends(get_current_user)): + """ + 获取声纹采集朗读模板配置 + + 权限:需要登录 + """ + try: + template_data = VoiceprintTemplate(**config_module.VOICEPRINT_CONFIG) + return create_api_response(code="200", message="获取朗读模板成功", data=template_data.dict()) + except Exception as e: + return create_api_response(code="500", message=f"获取朗读模板失败: {str(e)}") + + +@router.get("/voiceprint/{user_id}", response_model=None) +def get_voiceprint_status(user_id: int, current_user: dict = Depends(get_current_user)): + """ + 获取用户声纹采集状态 + + 权限:用户只能查询自己的声纹状态,管理员可查询所有 + """ + # 权限检查:只能查询自己的声纹,或者是管理员 + if current_user['user_id'] != user_id and current_user['role_id'] != 1: + return create_api_response(code="403", message="无权限查询其他用户的声纹状态") + + try: + status_data = voiceprint_service.get_user_voiceprint_status(user_id) + return create_api_response(code="200", message="获取声纹状态成功", data=status_data) + except Exception as e: + return create_api_response(code="500", message=f"获取声纹状态失败: {str(e)}") + + +@router.post("/voiceprint/{user_id}", response_model=None) +async def upload_voiceprint( + user_id: int, + audio_file: UploadFile = File(...), + current_user: dict = Depends(get_current_user) +): + """ + 上传声纹音频文件(同步处理) + + 权限:用户只能上传自己的声纹,管理员可操作所有 + """ + # 权限检查 + if current_user['user_id'] != user_id and current_user['role_id'] != 1: + return create_api_response(code="403", message="无权限上传其他用户的声纹") + + # 检查文件格式 + file_ext = Path(audio_file.filename).suffix.lower() + if file_ext not in config_module.ALLOWED_VOICEPRINT_EXTENSIONS: + return create_api_response( + code="400", + message=f"不支持的文件格式,仅支持: {', '.join(config_module.ALLOWED_VOICEPRINT_EXTENSIONS)}" + ) + + # 检查文件大小 + max_size = config_module.VOICEPRINT_CONFIG.get('max_file_size', 5242880) # 默认5MB + content = await audio_file.read() + file_size = len(content) + + if file_size > max_size: + return create_api_response( + code="400", + message=f"文件过大,最大允许 {max_size / 1024 / 1024:.1f}MB" + ) + + try: + # 确保用户目录存在 + user_voiceprint_dir = config_module.VOICEPRINT_DIR / str(user_id) + user_voiceprint_dir.mkdir(parents=True, exist_ok=True) + + # 生成文件名:时间戳.wav + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + filename = f"{timestamp}.wav" + file_path = user_voiceprint_dir / filename + + # 保存文件 + with open(file_path, "wb") as f: + f.write(content) + + # 调用服务处理声纹(提取特征向量,保存到数据库) + result = voiceprint_service.save_voiceprint(user_id, str(file_path), file_size) + + return create_api_response(code="200", message="声纹采集成功", data=result) + + except Exception as e: + # 如果出错,删除已上传的文件 + if 'file_path' in locals() and Path(file_path).exists(): + Path(file_path).unlink() + + return create_api_response(code="500", message=f"声纹采集失败: {str(e)}") + + +@router.delete("/voiceprint/{user_id}", response_model=None) +def delete_voiceprint(user_id: int, current_user: dict = Depends(get_current_user)): + """ + 删除用户声纹数据,允许重新采集 + + 权限:用户只能删除自己的声纹,管理员可操作所有 + """ + # 权限检查 + if current_user['user_id'] != user_id and current_user['role_id'] != 1: + return create_api_response(code="403", message="无权限删除其他用户的声纹") + + try: + success = voiceprint_service.delete_voiceprint(user_id) + + if success: + return create_api_response(code="200", message="声纹删除成功") + else: + return create_api_response(code="404", message="未找到该用户的声纹数据") + + except Exception as e: + return create_api_response(code="500", message=f"删除声纹失败: {str(e)}") diff --git a/app/core/config.py b/app/core/config.py index 1a9f5fb..d50d6f9 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -1,4 +1,5 @@ import os +import json from pathlib import Path # 基础路径配置 @@ -6,10 +7,12 @@ BASE_DIR = Path(__file__).parent.parent.parent UPLOAD_DIR = BASE_DIR / "uploads" AUDIO_DIR = UPLOAD_DIR / "audio" MARKDOWN_DIR = UPLOAD_DIR / "markdown" +VOICEPRINT_DIR = UPLOAD_DIR / "voiceprint" # 文件上传配置 ALLOWED_EXTENSIONS = {".mp3", ".wav", ".m4a", ".mpeg", ".mp4"} ALLOWED_IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".gif", ".webp"} +ALLOWED_VOICEPRINT_EXTENSIONS = {".wav"} MAX_FILE_SIZE = 100 * 1024 * 1024 # 100MB MAX_IMAGE_SIZE = 10 * 1024 * 1024 # 10MB @@ -17,6 +20,7 @@ MAX_IMAGE_SIZE = 10 * 1024 * 1024 # 10MB UPLOAD_DIR.mkdir(exist_ok=True) AUDIO_DIR.mkdir(exist_ok=True) MARKDOWN_DIR.mkdir(exist_ok=True) +VOICEPRINT_DIR.mkdir(exist_ok=True) # 数据库配置 DATABASE_CONFIG = { @@ -82,3 +86,12 @@ LLM_CONFIG = { # 密码重置配置 DEFAULT_RESET_PASSWORD = os.getenv('DEFAULT_RESET_PASSWORD', '111111') + +# 加载系统配置文件 +# 默认声纹配置 +VOICEPRINT_CONFIG = { + "template_text": "我正在进行声纹采集,这段语音将用于身份识别和验证。\n\n声纹技术能够准确识别每个人独特的声音特征。", + "duration_seconds": 12, + "sample_rate": 16000, + "channels": 1 +} diff --git a/app/models/models.py b/app/models/models.py index 0df7182..bfd9b86 100644 --- a/app/models/models.py +++ b/app/models/models.py @@ -128,7 +128,7 @@ class KnowledgeBase(BaseModel): is_shared: bool source_meeting_ids: Optional[str] = None user_prompt: Optional[str] = None - tags: Optional[List[Tag]] = [] + tags: Union[Optional[str], Optional[List[Tag]]] = None # 支持字符串或Tag列表 created_at: datetime.datetime updated_at: datetime.datetime source_meeting_count: Optional[int] = 0 @@ -204,3 +204,26 @@ class UpdateClientDownloadRequest(BaseModel): class ClientDownloadListResponse(BaseModel): clients: List[ClientDownload] total: int + +# 声纹采集相关模型 +class VoiceprintInfo(BaseModel): + vp_id: int + user_id: int + file_path: str + file_size: Optional[int] = None + duration_seconds: Optional[float] = None + collected_at: datetime.datetime + updated_at: datetime.datetime + +class VoiceprintStatus(BaseModel): + has_voiceprint: bool + vp_id: Optional[int] = None + file_path: Optional[str] = None + duration_seconds: Optional[float] = None + collected_at: Optional[datetime.datetime] = None + +class VoiceprintTemplate(BaseModel): + template_text: str + duration_seconds: int + sample_rate: int + channels: int diff --git a/app/services/voiceprint_service.py b/app/services/voiceprint_service.py new file mode 100644 index 0000000..13b04f7 --- /dev/null +++ b/app/services/voiceprint_service.py @@ -0,0 +1,218 @@ +""" +声纹服务 - 处理用户声纹采集、存储和验证 +""" +import os +import json +import wave +from datetime import datetime +from typing import Optional, Dict +from pathlib import Path + +from app.core.database import get_db_connection +import app.core.config as config_module + + +class VoiceprintService: + """声纹服务类 - 同步处理声纹采集""" + + def __init__(self): + self.voiceprint_dir = config_module.VOICEPRINT_DIR + self.voiceprint_config = config_module.VOICEPRINT_CONFIG + + def get_user_voiceprint_status(self, user_id: int) -> Dict: + """ + 获取用户声纹状态 + + Args: + user_id: 用户ID + + Returns: + Dict: 声纹状态信息 + """ + try: + with get_db_connection() as connection: + cursor = connection.cursor(dictionary=True) + query = """ + SELECT vp_id, user_id, file_path, file_size, duration_seconds, collected_at, updated_at + FROM user_voiceprint + WHERE user_id = %s + """ + cursor.execute(query, (user_id,)) + voiceprint = cursor.fetchone() + + if voiceprint: + return { + "has_voiceprint": True, + "vp_id": voiceprint['vp_id'], + "file_path": voiceprint['file_path'], + "duration_seconds": float(voiceprint['duration_seconds']) if voiceprint['duration_seconds'] else None, + "collected_at": voiceprint['collected_at'].isoformat() if voiceprint['collected_at'] else None + } + else: + return { + "has_voiceprint": False, + "vp_id": None, + "file_path": None, + "duration_seconds": None, + "collected_at": None + } + except Exception as e: + print(f"获取声纹状态错误: {e}") + raise e + + def save_voiceprint(self, user_id: int, audio_file_path: str, file_size: int) -> Dict: + """ + 保存声纹文件并提取特征向量 + + Args: + user_id: 用户ID + audio_file_path: 音频文件路径 + file_size: 文件大小 + + Returns: + Dict: 保存结果 + """ + try: + # 1. 获取音频时长 + duration = self._get_audio_duration(audio_file_path) + + # 2. 提取声纹向量(调用FunASR) + vector_data = self._extract_voiceprint_vector(audio_file_path) + + # 3. 保存到数据库 + with get_db_connection() as connection: + cursor = connection.cursor(dictionary=True) + + # 检查用户是否已有声纹 + cursor.execute("SELECT vp_id FROM user_voiceprint WHERE user_id = %s", (user_id,)) + existing = cursor.fetchone() + + # 计算相对路径 + relative_path = str(Path(audio_file_path).relative_to(config_module.BASE_DIR)) + + if existing: + # 更新现有记录 + update_query = """ + UPDATE user_voiceprint + SET file_path = %s, file_size = %s, duration_seconds = %s, + vector_data = %s, updated_at = NOW() + WHERE user_id = %s + """ + cursor.execute(update_query, ( + relative_path, file_size, duration, + json.dumps(vector_data) if vector_data else None, + user_id + )) + vp_id = existing['vp_id'] + else: + # 插入新记录 + insert_query = """ + INSERT INTO user_voiceprint + (user_id, file_path, file_size, duration_seconds, vector_data, collected_at, updated_at) + VALUES (%s, %s, %s, %s, %s, NOW(), NOW()) + """ + cursor.execute(insert_query, ( + user_id, relative_path, file_size, duration, + json.dumps(vector_data) if vector_data else None + )) + vp_id = cursor.lastrowid + + connection.commit() + + return { + "vp_id": vp_id, + "user_id": user_id, + "file_path": relative_path, + "file_size": file_size, + "duration_seconds": duration, + "has_vector": vector_data is not None + } + + except Exception as e: + print(f"保存声纹错误: {e}") + raise e + + def delete_voiceprint(self, user_id: int) -> bool: + """ + 删除用户声纹 + + Args: + user_id: 用户ID + + Returns: + bool: 是否删除成功 + """ + try: + with get_db_connection() as connection: + cursor = connection.cursor(dictionary=True) + + # 获取文件路径 + cursor.execute("SELECT file_path FROM user_voiceprint WHERE user_id = %s", (user_id,)) + voiceprint = cursor.fetchone() + + if voiceprint: + # 构建完整文件路径 + relative_path = voiceprint['file_path'] + if relative_path.startswith('/'): + relative_path = relative_path.lstrip('/') + file_path = config_module.BASE_DIR / relative_path + + # 删除数据库记录 + cursor.execute("DELETE FROM user_voiceprint WHERE user_id = %s", (user_id,)) + connection.commit() + + # 删除文件 + if file_path.exists(): + os.remove(file_path) + + return True + else: + return False + + except Exception as e: + print(f"删除声纹错误: {e}") + raise e + + def _get_audio_duration(self, audio_file_path: str) -> float: + """ + 获取音频文件时长 + + Args: + audio_file_path: 音频文件路径 + + Returns: + float: 时长(秒) + """ + try: + with wave.open(audio_file_path, 'rb') as wav_file: + frames = wav_file.getnframes() + rate = wav_file.getframerate() + duration = frames / float(rate) + return round(duration, 2) + except Exception as e: + print(f"获取音频时长错误: {e}") + return 10.0 # 默认返回10秒 + + def _extract_voiceprint_vector(self, audio_file_path: str) -> Optional[list]: + """ + 提取声纹特征向量(调用FunASR) + + Args: + audio_file_path: 音频文件路径 + + Returns: + Optional[list]: 声纹向量(192维),失败返回None + """ + # TODO: 集成FunASR的说话人识别模型 + # 使用 speech_campplus_sv_zh-cn_16k-common 模型 + # 返回192维的embedding向量 + + print(f"[TODO] 调用FunASR提取声纹向量: {audio_file_path}") + + # 暂时返回None,等待FunASR集成 + # 集成后应该返回类似: [0.123, -0.456, 0.789, ...] + return None + + +# 创建全局实例 +voiceprint_service = VoiceprintService() diff --git a/config/system_config.json b/config/system_config.json index 9cec9a3..92bdf2a 100644 --- a/config/system_config.json +++ b/config/system_config.json @@ -1,7 +1,7 @@ { "model_name": "qwen-plus", "system_prompt": "你是一个专业的会议记录分析助手。请根据提供的会议转录内容,生成简洁明了的会议总结。\n\n总结包括五个部分(名称严格一致,生成为MD二级目录):\n1. 会议概述 - 简要说明会议的主要目的和背景(生成MD引用)\n2. 主要讨论点 - 列出会议中讨论的重要话题和内容\n3. 决策事项 - 明确记录会议中做出的决定和结论\n4. 待办事项 - 列出需要后续跟进的任务和责任人\n5. 关键信息 - 其他重要的信息点\n\n输出要求:\n- 保持客观中性,不添加个人观点\n- 使用简洁、准确的中文表达\n- 按重要性排序各项内容\n- 如果某个部分没有相关内容,可以说明\"无相关内容\"\n- 总字数控制在500字以内", - "DEFAULT_RESET_PASSWORD": "111111", - "MAX_FILE_SIZE": 209715200, + "DEFAULT_RESET_PASSWORD": "123456", + "MAX_FILE_SIZE": 208666624, "MAX_IMAGE_SIZE": 10485760 } \ No newline at end of file diff --git a/main.py b/main.py index 22836e0..2e00e33 100644 --- a/main.py +++ b/main.py @@ -2,7 +2,7 @@ import uvicorn from fastapi import FastAPI, Request, HTTPException from fastapi.middleware.cors import CORSMiddleware from fastapi.staticfiles import StaticFiles -from app.api.endpoints import auth, users, meetings, tags, admin, tasks, prompts, knowledge_base, client_downloads +from app.api.endpoints import auth, users, meetings, tags, admin, tasks, prompts, knowledge_base, client_downloads, voiceprint from app.core.config import UPLOAD_DIR, API_CONFIG from app.api.endpoints.admin import load_system_config import os @@ -39,6 +39,7 @@ app.include_router(tasks.router, prefix="/api", tags=["Tasks"]) app.include_router(prompts.router, prefix="/api", tags=["Prompts"]) app.include_router(knowledge_base.router, prefix="/api", tags=["KnowledgeBase"]) app.include_router(client_downloads.router, prefix="/api/clients", tags=["ClientDownloads"]) +app.include_router(voiceprint.router, prefix="/api", tags=["Voiceprint"]) @app.get("/") def read_root(): diff --git a/migrations/add_meetings_fields.sql b/migrations/add_meetings_fields.sql deleted file mode 100644 index af532f7..0000000 --- a/migrations/add_meetings_fields.sql +++ /dev/null @@ -1,17 +0,0 @@ --- 为meetings表添加updated_at和user_prompt字段 --- 执行日期: 2025-10-28 - --- 添加updated_at字段 -ALTER TABLE meetings -ADD COLUMN updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP -AFTER created_at; - --- 添加user_prompt字段 -ALTER TABLE meetings -ADD COLUMN user_prompt TEXT -AFTER summary; - --- 为现有记录设置updated_at为created_at的值 -UPDATE meetings -SET updated_at = created_at -WHERE updated_at IS NULL; diff --git a/test_voiceprint_api.py b/test_voiceprint_api.py new file mode 100644 index 0000000..eebfa19 --- /dev/null +++ b/test_voiceprint_api.py @@ -0,0 +1,141 @@ +""" +声纹采集API测试脚本 + +使用方法: +1. 确保后端服务正在运行 +2. 修改 USER_ID 和 TOKEN 为实际值 +3. 准备一个10秒的WAV音频文件 +4. 运行: python test_voiceprint_api.py +""" + +import requests +import json + +# 配置 +BASE_URL = "http://localhost:8000/api" +USER_ID = 1 # 修改为实际用户ID +TOKEN = "" # 登录后获取的token + +# 请求头 +headers = { + "Authorization": f"Bearer {TOKEN}", + "Content-Type": "application/json" +} + + +def test_get_template(): + """测试获取朗读模板""" + print("\n=== 测试1: 获取朗读模板 ===") + url = f"{BASE_URL}/voiceprint/template" + response = requests.get(url, headers=headers) + print(f"状态码: {response.status_code}") + print(f"响应: {json.dumps(response.json(), ensure_ascii=False, indent=2)}") + return response.json() + + +def test_get_status(user_id): + """测试获取声纹状态""" + print(f"\n=== 测试2: 获取用户 {user_id} 的声纹状态 ===") + url = f"{BASE_URL}/voiceprint/{user_id}" + response = requests.get(url, headers=headers) + print(f"状态码: {response.status_code}") + print(f"响应: {json.dumps(response.json(), ensure_ascii=False, indent=2)}") + return response.json() + + +def test_upload_voiceprint(user_id, audio_file_path): + """测试上传声纹""" + print(f"\n=== 测试3: 上传声纹音频 ===") + url = f"{BASE_URL}/voiceprint/{user_id}" + + # 移除Content-Type,让requests自动设置multipart/form-data + upload_headers = { + "Authorization": f"Bearer {TOKEN}" + } + + with open(audio_file_path, 'rb') as f: + files = {'audio_file': (audio_file_path.split('/')[-1], f, 'audio/wav')} + response = requests.post(url, headers=upload_headers, files=files) + + print(f"状态码: {response.status_code}") + print(f"响应: {json.dumps(response.json(), ensure_ascii=False, indent=2)}") + return response.json() + + +def test_delete_voiceprint(user_id): + """测试删除声纹""" + print(f"\n=== 测试4: 删除用户 {user_id} 的声纹 ===") + url = f"{BASE_URL}/voiceprint/{user_id}" + response = requests.delete(url, headers=headers) + print(f"状态码: {response.status_code}") + print(f"响应: {json.dumps(response.json(), ensure_ascii=False, indent=2)}") + return response.json() + + +def login(username, password): + """登录获取token""" + print("\n=== 登录获取Token ===") + url = f"{BASE_URL}/auth/login" + data = { + "username": username, + "password": password + } + response = requests.post(url, json=data) + if response.status_code == 200: + result = response.json() + if result.get('code') == '200': + token = result['data']['token'] + print(f"登录成功,Token: {token[:20]}...") + return token + else: + print(f"登录失败: {result.get('message')}") + return None + else: + print(f"请求失败,状态码: {response.status_code}") + return None + + +if __name__ == "__main__": + print("=" * 60) + print("声纹采集API测试脚本") + print("=" * 60) + + # 步骤1: 登录(如果没有token) + if not TOKEN: + print("\n请先登录获取Token...") + username = input("用户名: ") + password = input("密码: ") + TOKEN = login(username, password) + if TOKEN: + headers["Authorization"] = f"Bearer {TOKEN}" + else: + print("登录失败,退出测试") + exit(1) + + # 步骤2: 测试获取朗读模板 + test_get_template() + + # 步骤3: 测试获取声纹状态 + test_get_status(USER_ID) + + # 步骤4: 测试上传声纹(需要准备音频文件) + audio_file = input("\n请输入WAV音频文件路径 (回车跳过上传测试): ") + if audio_file.strip(): + test_upload_voiceprint(USER_ID, audio_file.strip()) + + # 上传后再次查看状态 + print("\n=== 上传后再次查看状态 ===") + test_get_status(USER_ID) + + # 步骤5: 测试删除声纹 + confirm = input("\n是否测试删除声纹? (yes/no): ") + if confirm.lower() == 'yes': + test_delete_voiceprint(USER_ID) + + # 删除后再次查看状态 + print("\n=== 删除后再次查看状态 ===") + test_get_status(USER_ID) + + print("\n" + "=" * 60) + print("测试完成") + print("=" * 60)