Compare commits

...

2 Commits

Author SHA1 Message Date
mula.liu 2f36474f4d 1.0.3 2025-10-31 14:54:54 +08:00
mula.liu 976ea854b6 整理了会议和知识库的代码结构 2025-10-28 19:30:09 +08:00
19 changed files with 913 additions and 356 deletions

BIN
.DS_Store vendored

Binary file not shown.

BIN
app.zip

Binary file not shown.

View File

@ -9,29 +9,25 @@ import datetime
router = APIRouter() router = APIRouter()
def _process_tags(cursor, tag_string: Optional[str]) -> List[Tag]: def _process_tags(cursor, tag_string: Optional[str], creator_id: Optional[int] = None) -> List[Tag]:
"""
处理标签查询已存在的标签如果提供了 creator_id 则创建不存在的标签
"""
if not tag_string: if not tag_string:
return [] return []
tag_names = [name.strip() for name in tag_string.split(',') if name.strip()] tag_names = [name.strip() for name in tag_string.split(',') if name.strip()]
if not tag_names: if not tag_names:
return [] return []
placeholders = ','.join(['%s'] * len(tag_names)) # 如果提供了 creator_id则创建不存在的标签
select_query = f"SELECT id, name, color FROM tags WHERE name IN ({placeholders})" if creator_id:
cursor.execute(select_query, tuple(tag_names)) insert_ignore_query = "INSERT IGNORE INTO tags (name, creator_id) VALUES (%s, %s)"
cursor.executemany(insert_ignore_query, [(name, creator_id) for name in tag_names])
# 查询所有标签信息
format_strings = ', '.join(['%s'] * len(tag_names))
cursor.execute(f"SELECT id, name, color FROM tags WHERE name IN ({format_strings})", tuple(tag_names))
tags_data = cursor.fetchall() tags_data = cursor.fetchall()
existing_tags = {tag['name']: tag for tag in tags_data}
new_tags = [name for name in tag_names if name not in existing_tags]
if new_tags:
insert_query = "INSERT INTO tags (name) VALUES (%s)"
cursor.executemany(insert_query, [(name,) for name in new_tags])
# Re-fetch all tags to get their IDs and default colors
cursor.execute(select_query, tuple(tag_names))
tags_data = cursor.fetchall()
return [Tag(**tag) for tag in tags_data] return [Tag(**tag) for tag in tags_data]
@ -83,7 +79,8 @@ def get_knowledge_bases(
kb_list = [] kb_list = []
for kb_data in kbs_data: for kb_data in kbs_data:
kb_data['tags'] = _process_tags(cursor, kb_data.get('tags')) # 列表页不需要处理 tags直接使用字符串
# kb_data['tags'] 保持原样(逗号分隔的标签名称字符串)
# Count source meetings - filter empty strings # Count source meetings - filter empty strings
if kb_data.get('source_meeting_ids'): if kb_data.get('source_meeting_ids'):
meeting_ids = [mid.strip() for mid in kb_data['source_meeting_ids'].split(',') if mid.strip()] meeting_ids = [mid.strip() for mid in kb_data['source_meeting_ids'].split(',') if mid.strip()]
@ -122,7 +119,7 @@ def create_knowledge_base(
request.is_shared, request.is_shared,
request.source_meeting_ids, request.source_meeting_ids,
request.user_prompt, request.user_prompt,
request.tags, request.tags, # 创建时 tags 应该为 None 或空字符串
now, now,
now now
)) ))
@ -136,7 +133,7 @@ def create_knowledge_base(
source_meeting_ids=request.source_meeting_ids, source_meeting_ids=request.source_meeting_ids,
cursor=cursor cursor=cursor
) )
connection.commit() connection.commit()
# Add the background task to process the knowledge base generation # Add the background task to process the knowledge base generation
@ -171,7 +168,8 @@ def get_knowledge_base_detail(
if not kb_data['is_shared'] and kb_data['creator_id'] != current_user['user_id']: if not kb_data['is_shared'] and kb_data['creator_id'] != current_user['user_id']:
raise HTTPException(status_code=403, detail="Access denied") raise HTTPException(status_code=403, detail="Access denied")
# Process tags # Process tags - 获取标签的完整信息(包括颜色)
# 详情页不需要创建新标签,所以不传 creator_id
kb_data['tags'] = _process_tags(cursor, kb_data.get('tags')) kb_data['tags'] = _process_tags(cursor, kb_data.get('tags'))
# Get source meetings details # Get source meetings details
@ -220,6 +218,10 @@ def update_knowledge_base(
if kb['creator_id'] != current_user['user_id']: if kb['creator_id'] != current_user['user_id']:
raise HTTPException(status_code=403, detail="Only the creator can update this knowledge base") raise HTTPException(status_code=403, detail="Only the creator can update this knowledge base")
# 使用 _process_tags 处理标签(会自动创建新标签)
if request.tags:
_process_tags(cursor, request.tags, current_user['user_id'])
# Update the knowledge base # Update the knowledge base
now = datetime.datetime.utcnow() now = datetime.datetime.utcnow()
update_query = """ update_query = """

View File

@ -5,7 +5,7 @@ from app.core.config import BASE_DIR, AUDIO_DIR, MARKDOWN_DIR, ALLOWED_EXTENSION
import app.core.config as config_module import app.core.config as config_module
from app.services.llm_service import LLMService from app.services.llm_service import LLMService
from app.services.async_transcription_service import AsyncTranscriptionService from app.services.async_transcription_service import AsyncTranscriptionService
from app.services.async_llm_service import async_llm_service from app.services.async_meeting_service import async_meeting_service
from app.core.auth import get_current_user from app.core.auth import get_current_user
from app.core.response import create_api_response from app.core.response import create_api_response
from typing import List, Optional from typing import List, Optional
@ -23,14 +23,22 @@ transcription_service = AsyncTranscriptionService()
class GenerateSummaryRequest(BaseModel): class GenerateSummaryRequest(BaseModel):
user_prompt: Optional[str] = "" user_prompt: Optional[str] = ""
def _process_tags(cursor, tag_string: Optional[str]) -> List[Tag]: def _process_tags(cursor, tag_string: Optional[str], creator_id: Optional[int] = None) -> List[Tag]:
"""
处理标签查询已存在的标签如果提供了 creator_id 则创建不存在的标签
"""
if not tag_string: if not tag_string:
return [] return []
tag_names = [name.strip() for name in tag_string.split(',') if name.strip()] tag_names = [name.strip() for name in tag_string.split(',') if name.strip()]
if not tag_names: if not tag_names:
return [] return []
insert_ignore_query = "INSERT IGNORE INTO tags (name) VALUES (%s)"
cursor.executemany(insert_ignore_query, [(name,) for name in tag_names]) # 如果提供了 creator_id则创建不存在的标签
if creator_id:
insert_ignore_query = "INSERT IGNORE INTO tags (name, creator_id) VALUES (%s, %s)"
cursor.executemany(insert_ignore_query, [(name, creator_id) for name in tag_names])
# 查询所有标签信息
format_strings = ', '.join(['%s'] * len(tag_names)) format_strings = ', '.join(['%s'] * len(tag_names))
cursor.execute(f"SELECT id, name, color FROM tags WHERE name IN ({format_strings})", tuple(tag_names)) cursor.execute(f"SELECT id, name, color FROM tags WHERE name IN ({format_strings})", tuple(tag_names))
tags_data = cursor.fetchall() tags_data = cursor.fetchall()
@ -125,10 +133,9 @@ def get_meeting_transcript(meeting_id: int, current_user: dict = Depends(get_cur
def create_meeting(meeting_request: CreateMeetingRequest, current_user: dict = Depends(get_current_user)): def create_meeting(meeting_request: CreateMeetingRequest, current_user: dict = Depends(get_current_user)):
with get_db_connection() as connection: with get_db_connection() as connection:
cursor = connection.cursor(dictionary=True) cursor = connection.cursor(dictionary=True)
# 使用 _process_tags 来处理标签创建
if meeting_request.tags: if meeting_request.tags:
tag_names = [name.strip() for name in meeting_request.tags.split(',') if name.strip()] _process_tags(cursor, meeting_request.tags, current_user['user_id'])
if tag_names:
cursor.executemany("INSERT IGNORE INTO tags (name) VALUES (%s)", [(name,) for name in tag_names])
meeting_query = 'INSERT INTO meetings (user_id, title, meeting_time, summary, tags, created_at) VALUES (%s, %s, %s, %s, %s, %s)' meeting_query = 'INSERT INTO meetings (user_id, title, meeting_time, summary, tags, created_at) VALUES (%s, %s, %s, %s, %s, %s)'
cursor.execute(meeting_query, (meeting_request.user_id, meeting_request.title, meeting_request.meeting_time, None, meeting_request.tags, datetime.now().isoformat())) cursor.execute(meeting_query, (meeting_request.user_id, meeting_request.title, meeting_request.meeting_time, None, meeting_request.tags, datetime.now().isoformat()))
meeting_id = cursor.lastrowid meeting_id = cursor.lastrowid
@ -147,10 +154,9 @@ def update_meeting(meeting_id: int, meeting_request: UpdateMeetingRequest, curre
return create_api_response(code="404", message="Meeting not found") return create_api_response(code="404", message="Meeting not found")
if meeting['user_id'] != current_user['user_id']: if meeting['user_id'] != current_user['user_id']:
return create_api_response(code="403", message="Permission denied") return create_api_response(code="403", message="Permission denied")
# 使用 _process_tags 来处理标签创建
if meeting_request.tags: if meeting_request.tags:
tag_names = [name.strip() for name in meeting_request.tags.split(',') if name.strip()] _process_tags(cursor, meeting_request.tags, current_user['user_id'])
if tag_names:
cursor.executemany("INSERT IGNORE INTO tags (name) VALUES (%s)", [(name,) for name in tag_names])
update_query = 'UPDATE meetings SET title = %s, meeting_time = %s, summary = %s, tags = %s WHERE meeting_id = %s' update_query = 'UPDATE meetings SET title = %s, meeting_time = %s, summary = %s, tags = %s WHERE meeting_id = %s'
cursor.execute(update_query, (meeting_request.title, meeting_request.meeting_time, meeting_request.summary, meeting_request.tags, meeting_id)) cursor.execute(update_query, (meeting_request.title, meeting_request.meeting_time, meeting_request.summary, meeting_request.tags, meeting_id))
cursor.execute("DELETE FROM attendees WHERE meeting_id = %s", (meeting_id,)) cursor.execute("DELETE FROM attendees WHERE meeting_id = %s", (meeting_id,))
@ -449,8 +455,8 @@ def generate_meeting_summary_async(meeting_id: int, request: GenerateSummaryRequ
cursor.execute("SELECT meeting_id FROM meetings WHERE meeting_id = %s", (meeting_id,)) cursor.execute("SELECT meeting_id FROM meetings WHERE meeting_id = %s", (meeting_id,))
if not cursor.fetchone(): if not cursor.fetchone():
return create_api_response(code="404", message="Meeting not found") return create_api_response(code="404", message="Meeting not found")
task_id = async_llm_service.start_summary_generation(meeting_id, request.user_prompt) task_id = async_meeting_service.start_summary_generation(meeting_id, request.user_prompt)
background_tasks.add_task(async_llm_service._process_task, task_id) background_tasks.add_task(async_meeting_service._process_task, task_id)
return create_api_response(code="200", message="Summary generation task has been accepted.", data={ return create_api_response(code="200", message="Summary generation task has been accepted.", data={
"task_id": task_id, "status": "pending", "meeting_id": meeting_id "task_id": task_id, "status": "pending", "meeting_id": meeting_id
}) })
@ -465,7 +471,7 @@ def get_meeting_llm_tasks(meeting_id: int, current_user: dict = Depends(get_curr
cursor.execute("SELECT meeting_id FROM meetings WHERE meeting_id = %s", (meeting_id,)) cursor.execute("SELECT meeting_id FROM meetings WHERE meeting_id = %s", (meeting_id,))
if not cursor.fetchone(): if not cursor.fetchone():
return create_api_response(code="404", message="Meeting not found") return create_api_response(code="404", message="Meeting not found")
tasks = async_llm_service.get_meeting_llm_tasks(meeting_id) tasks = async_meeting_service.get_meeting_llm_tasks(meeting_id)
return create_api_response(code="200", message="LLM tasks retrieved successfully", data={ return create_api_response(code="200", message="LLM tasks retrieved successfully", data={
"tasks": tasks, "total": len(tasks) "tasks": tasks, "total": len(tasks)
}) })

View File

@ -2,7 +2,7 @@ from fastapi import APIRouter, Depends
from pydantic import BaseModel from pydantic import BaseModel
from typing import List, Optional from typing import List, Optional
from app.core.auth import get_current_admin_user from app.core.auth import get_current_user
from app.core.database import get_db_connection from app.core.database import get_db_connection
from app.core.response import create_api_response from app.core.response import create_api_response
@ -23,14 +23,14 @@ class PromptListResponse(BaseModel):
total: int total: int
@router.post("/prompts") @router.post("/prompts")
def create_prompt(prompt: PromptIn, current_user: dict = Depends(get_current_admin_user)): def create_prompt(prompt: PromptIn, current_user: dict = Depends(get_current_user)):
"""Create a new prompt.""" """Create a new prompt."""
with get_db_connection() as connection: with get_db_connection() as connection:
cursor = connection.cursor(dictionary=True) cursor = connection.cursor(dictionary=True)
try: try:
cursor.execute( cursor.execute(
"INSERT INTO prompts (name, tags, content) VALUES (%s, %s, %s)", "INSERT INTO prompts (name, tags, content, creator_id) VALUES (%s, %s, %s, %s)",
(prompt.name, prompt.tags, prompt.content) (prompt.name, prompt.tags, prompt.content, current_user["user_id"])
) )
connection.commit() connection.commit()
new_id = cursor.lastrowid new_id = cursor.lastrowid
@ -41,23 +41,27 @@ def create_prompt(prompt: PromptIn, current_user: dict = Depends(get_current_adm
return create_api_response(code="500", message=f"创建提示词失败: {e}") return create_api_response(code="500", message=f"创建提示词失败: {e}")
@router.get("/prompts") @router.get("/prompts")
def get_prompts(page: int = 1, size: int = 12, current_user: dict = Depends(get_current_admin_user)): def get_prompts(page: int = 1, size: int = 12, current_user: dict = Depends(get_current_user)):
"""Get a paginated list of prompts.""" """Get a paginated list of prompts filtered by current user."""
with get_db_connection() as connection: with get_db_connection() as connection:
cursor = connection.cursor(dictionary=True) cursor = connection.cursor(dictionary=True)
cursor.execute("SELECT COUNT(*) as total FROM prompts") # 只获取当前用户创建的提示词
cursor.execute(
"SELECT COUNT(*) as total FROM prompts WHERE creator_id = %s",
(current_user["user_id"],)
)
total = cursor.fetchone()['total'] total = cursor.fetchone()['total']
offset = (page - 1) * size offset = (page - 1) * size
cursor.execute( cursor.execute(
"SELECT id, name, tags, content, created_at FROM prompts ORDER BY created_at DESC LIMIT %s OFFSET %s", "SELECT id, name, tags, content, created_at FROM prompts WHERE creator_id = %s ORDER BY created_at DESC LIMIT %s OFFSET %s",
(size, offset) (current_user["user_id"], size, offset)
) )
prompts = cursor.fetchall() prompts = cursor.fetchall()
return create_api_response(code="200", message="获取提示词列表成功", data={"prompts": prompts, "total": total}) return create_api_response(code="200", message="获取提示词列表成功", data={"prompts": prompts, "total": total})
@router.get("/prompts/{prompt_id}") @router.get("/prompts/{prompt_id}")
def get_prompt(prompt_id: int, current_user: dict = Depends(get_current_admin_user)): def get_prompt(prompt_id: int, current_user: dict = Depends(get_current_user)):
"""Get a single prompt by its ID.""" """Get a single prompt by its ID."""
with get_db_connection() as connection: with get_db_connection() as connection:
cursor = connection.cursor(dictionary=True) cursor = connection.cursor(dictionary=True)
@ -68,7 +72,7 @@ def get_prompt(prompt_id: int, current_user: dict = Depends(get_current_admin_us
return create_api_response(code="200", message="获取提示词成功", data=prompt) return create_api_response(code="200", message="获取提示词成功", data=prompt)
@router.put("/prompts/{prompt_id}") @router.put("/prompts/{prompt_id}")
def update_prompt(prompt_id: int, prompt: PromptIn, current_user: dict = Depends(get_current_admin_user)): def update_prompt(prompt_id: int, prompt: PromptIn, current_user: dict = Depends(get_current_user)):
"""Update an existing prompt.""" """Update an existing prompt."""
with get_db_connection() as connection: with get_db_connection() as connection:
cursor = connection.cursor(dictionary=True) cursor = connection.cursor(dictionary=True)
@ -87,12 +91,24 @@ def update_prompt(prompt_id: int, prompt: PromptIn, current_user: dict = Depends
return create_api_response(code="500", message=f"更新提示词失败: {e}") return create_api_response(code="500", message=f"更新提示词失败: {e}")
@router.delete("/prompts/{prompt_id}") @router.delete("/prompts/{prompt_id}")
def delete_prompt(prompt_id: int, current_user: dict = Depends(get_current_admin_user)): def delete_prompt(prompt_id: int, current_user: dict = Depends(get_current_user)):
"""Delete a prompt.""" """Delete a prompt. Only the creator can delete their own prompts."""
with get_db_connection() as connection: with get_db_connection() as connection:
cursor = connection.cursor() cursor = connection.cursor(dictionary=True)
cursor.execute("DELETE FROM prompts WHERE id = %s", (prompt_id,)) # 首先检查提示词是否存在以及是否属于当前用户
if cursor.rowcount == 0: cursor.execute(
"SELECT creator_id FROM prompts WHERE id = %s",
(prompt_id,)
)
prompt = cursor.fetchone()
if not prompt:
return create_api_response(code="404", message="提示词不存在") return create_api_response(code="404", message="提示词不存在")
if prompt['creator_id'] != current_user["user_id"]:
return create_api_response(code="403", message="无权删除其他用户的提示词")
# 删除提示词
cursor.execute("DELETE FROM prompts WHERE id = %s", (prompt_id,))
connection.commit() connection.commit()
return create_api_response(code="200", message="提示词删除成功") return create_api_response(code="200", message="提示词删除成功")

View File

@ -1,6 +1,7 @@
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends
from app.core.database import get_db_connection from app.core.database import get_db_connection
from app.core.response import create_api_response from app.core.response import create_api_response
from app.core.auth import get_current_user
from app.models.models import Tag from app.models.models import Tag
from typing import List from typing import List
import mysql.connector import mysql.connector
@ -24,16 +25,16 @@ def get_all_tags():
return create_api_response(code="500", message="获取标签失败") return create_api_response(code="500", message="获取标签失败")
@router.post("/tags/") @router.post("/tags/")
def create_tag(tag_in: Tag): def create_tag(tag_in: Tag, current_user: dict = Depends(get_current_user)):
"""_summary_ """_summary_
创建一个新标签 创建一个新标签并记录创建者
""" """
query = "INSERT INTO tags (name, color) VALUES (%s, %s)" query = "INSERT INTO tags (name, color, creator_id) VALUES (%s, %s, %s)"
try: try:
with get_db_connection() as connection: with get_db_connection() as connection:
with connection.cursor(dictionary=True) as cursor: with connection.cursor(dictionary=True) as cursor:
try: try:
cursor.execute(query, (tag_in.name, tag_in.color)) cursor.execute(query, (tag_in.name, tag_in.color, current_user["user_id"]))
connection.commit() connection.commit()
tag_id = cursor.lastrowid tag_id = cursor.lastrowid
new_tag = {"id": tag_id, "name": tag_in.name, "color": tag_in.color} new_tag = {"id": tag_id, "name": tag_in.name, "color": tag_in.color}

View File

@ -2,7 +2,7 @@ from fastapi import APIRouter, Depends
from app.core.auth import get_current_user from app.core.auth import get_current_user
from app.core.response import create_api_response from app.core.response import create_api_response
from app.services.async_transcription_service import AsyncTranscriptionService from app.services.async_transcription_service import AsyncTranscriptionService
from app.services.async_llm_service import async_llm_service from app.services.async_meeting_service import async_meeting_service
router = APIRouter() router = APIRouter()
@ -23,7 +23,7 @@ def get_transcription_task_status(task_id: str, current_user: dict = Depends(get
def get_llm_task_status(task_id: str, current_user: dict = Depends(get_current_user)): def get_llm_task_status(task_id: str, current_user: dict = Depends(get_current_user)):
"""获取LLM总结任务状态包括进度""" """获取LLM总结任务状态包括进度"""
try: try:
status = async_llm_service.get_task_status(task_id) status = async_meeting_service.get_task_status(task_id)
if status.get('status') == 'not_found': if status.get('status') == 'not_found':
return create_api_response(code="404", message="Task not found") return create_api_response(code="404", message="Task not found")
return create_api_response(code="200", message="Task status retrieved", data=status) return create_api_response(code="200", message="Task status retrieved", data=status)

View File

@ -0,0 +1,131 @@
"""
声纹采集API接口
"""
from fastapi import APIRouter, Depends, UploadFile, File, HTTPException
from typing import Optional
from pathlib import Path
import datetime
from app.models.models import VoiceprintStatus, VoiceprintTemplate
from app.core.auth import get_current_user
from app.core.response import create_api_response
from app.services.voiceprint_service import voiceprint_service
import app.core.config as config_module
router = APIRouter()
@router.get("/voiceprint/template", response_model=None)
def get_voiceprint_template(current_user: dict = Depends(get_current_user)):
"""
获取声纹采集朗读模板配置
权限需要登录
"""
try:
template_data = VoiceprintTemplate(**config_module.VOICEPRINT_CONFIG)
return create_api_response(code="200", message="获取朗读模板成功", data=template_data.dict())
except Exception as e:
return create_api_response(code="500", message=f"获取朗读模板失败: {str(e)}")
@router.get("/voiceprint/{user_id}", response_model=None)
def get_voiceprint_status(user_id: int, current_user: dict = Depends(get_current_user)):
"""
获取用户声纹采集状态
权限用户只能查询自己的声纹状态管理员可查询所有
"""
# 权限检查:只能查询自己的声纹,或者是管理员
if current_user['user_id'] != user_id and current_user['role_id'] != 1:
return create_api_response(code="403", message="无权限查询其他用户的声纹状态")
try:
status_data = voiceprint_service.get_user_voiceprint_status(user_id)
return create_api_response(code="200", message="获取声纹状态成功", data=status_data)
except Exception as e:
return create_api_response(code="500", message=f"获取声纹状态失败: {str(e)}")
@router.post("/voiceprint/{user_id}", response_model=None)
async def upload_voiceprint(
user_id: int,
audio_file: UploadFile = File(...),
current_user: dict = Depends(get_current_user)
):
"""
上传声纹音频文件同步处理
权限用户只能上传自己的声纹管理员可操作所有
"""
# 权限检查
if current_user['user_id'] != user_id and current_user['role_id'] != 1:
return create_api_response(code="403", message="无权限上传其他用户的声纹")
# 检查文件格式
file_ext = Path(audio_file.filename).suffix.lower()
if file_ext not in config_module.ALLOWED_VOICEPRINT_EXTENSIONS:
return create_api_response(
code="400",
message=f"不支持的文件格式,仅支持: {', '.join(config_module.ALLOWED_VOICEPRINT_EXTENSIONS)}"
)
# 检查文件大小
max_size = config_module.VOICEPRINT_CONFIG.get('max_file_size', 5242880) # 默认5MB
content = await audio_file.read()
file_size = len(content)
if file_size > max_size:
return create_api_response(
code="400",
message=f"文件过大,最大允许 {max_size / 1024 / 1024:.1f}MB"
)
try:
# 确保用户目录存在
user_voiceprint_dir = config_module.VOICEPRINT_DIR / str(user_id)
user_voiceprint_dir.mkdir(parents=True, exist_ok=True)
# 生成文件名:时间戳.wav
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
filename = f"{timestamp}.wav"
file_path = user_voiceprint_dir / filename
# 保存文件
with open(file_path, "wb") as f:
f.write(content)
# 调用服务处理声纹(提取特征向量,保存到数据库)
result = voiceprint_service.save_voiceprint(user_id, str(file_path), file_size)
return create_api_response(code="200", message="声纹采集成功", data=result)
except Exception as e:
# 如果出错,删除已上传的文件
if 'file_path' in locals() and Path(file_path).exists():
Path(file_path).unlink()
return create_api_response(code="500", message=f"声纹采集失败: {str(e)}")
@router.delete("/voiceprint/{user_id}", response_model=None)
def delete_voiceprint(user_id: int, current_user: dict = Depends(get_current_user)):
"""
删除用户声纹数据允许重新采集
权限用户只能删除自己的声纹管理员可操作所有
"""
# 权限检查
if current_user['user_id'] != user_id and current_user['role_id'] != 1:
return create_api_response(code="403", message="无权限删除其他用户的声纹")
try:
success = voiceprint_service.delete_voiceprint(user_id)
if success:
return create_api_response(code="200", message="声纹删除成功")
else:
return create_api_response(code="404", message="未找到该用户的声纹数据")
except Exception as e:
return create_api_response(code="500", message=f"删除声纹失败: {str(e)}")

View File

@ -1,4 +1,5 @@
import os import os
import json
from pathlib import Path from pathlib import Path
# 基础路径配置 # 基础路径配置
@ -6,10 +7,12 @@ BASE_DIR = Path(__file__).parent.parent.parent
UPLOAD_DIR = BASE_DIR / "uploads" UPLOAD_DIR = BASE_DIR / "uploads"
AUDIO_DIR = UPLOAD_DIR / "audio" AUDIO_DIR = UPLOAD_DIR / "audio"
MARKDOWN_DIR = UPLOAD_DIR / "markdown" MARKDOWN_DIR = UPLOAD_DIR / "markdown"
VOICEPRINT_DIR = UPLOAD_DIR / "voiceprint"
# 文件上传配置 # 文件上传配置
ALLOWED_EXTENSIONS = {".mp3", ".wav", ".m4a", ".mpeg", ".mp4"} ALLOWED_EXTENSIONS = {".mp3", ".wav", ".m4a", ".mpeg", ".mp4"}
ALLOWED_IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".gif", ".webp"} ALLOWED_IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".gif", ".webp"}
ALLOWED_VOICEPRINT_EXTENSIONS = {".wav"}
MAX_FILE_SIZE = 100 * 1024 * 1024 # 100MB MAX_FILE_SIZE = 100 * 1024 * 1024 # 100MB
MAX_IMAGE_SIZE = 10 * 1024 * 1024 # 10MB MAX_IMAGE_SIZE = 10 * 1024 * 1024 # 10MB
@ -17,6 +20,7 @@ MAX_IMAGE_SIZE = 10 * 1024 * 1024 # 10MB
UPLOAD_DIR.mkdir(exist_ok=True) UPLOAD_DIR.mkdir(exist_ok=True)
AUDIO_DIR.mkdir(exist_ok=True) AUDIO_DIR.mkdir(exist_ok=True)
MARKDOWN_DIR.mkdir(exist_ok=True) MARKDOWN_DIR.mkdir(exist_ok=True)
VOICEPRINT_DIR.mkdir(exist_ok=True)
# 数据库配置 # 数据库配置
DATABASE_CONFIG = { DATABASE_CONFIG = {
@ -82,3 +86,12 @@ LLM_CONFIG = {
# 密码重置配置 # 密码重置配置
DEFAULT_RESET_PASSWORD = os.getenv('DEFAULT_RESET_PASSWORD', '111111') DEFAULT_RESET_PASSWORD = os.getenv('DEFAULT_RESET_PASSWORD', '111111')
# 加载系统配置文件
# 默认声纹配置
VOICEPRINT_CONFIG = {
"template_text": "我正在进行声纹采集,这段语音将用于身份识别和验证。\n\n声纹技术能够准确识别每个人独特的声音特征。",
"duration_seconds": 12,
"sample_rate": 16000,
"channels": 1
}

View File

@ -128,7 +128,7 @@ class KnowledgeBase(BaseModel):
is_shared: bool is_shared: bool
source_meeting_ids: Optional[str] = None source_meeting_ids: Optional[str] = None
user_prompt: Optional[str] = None user_prompt: Optional[str] = None
tags: Optional[List[Tag]] = [] tags: Union[Optional[str], Optional[List[Tag]]] = None # 支持字符串或Tag列表
created_at: datetime.datetime created_at: datetime.datetime
updated_at: datetime.datetime updated_at: datetime.datetime
source_meeting_count: Optional[int] = 0 source_meeting_count: Optional[int] = 0
@ -204,3 +204,26 @@ class UpdateClientDownloadRequest(BaseModel):
class ClientDownloadListResponse(BaseModel): class ClientDownloadListResponse(BaseModel):
clients: List[ClientDownload] clients: List[ClientDownload]
total: int total: int
# 声纹采集相关模型
class VoiceprintInfo(BaseModel):
vp_id: int
user_id: int
file_path: str
file_size: Optional[int] = None
duration_seconds: Optional[float] = None
collected_at: datetime.datetime
updated_at: datetime.datetime
class VoiceprintStatus(BaseModel):
has_voiceprint: bool
vp_id: Optional[int] = None
file_path: Optional[str] = None
duration_seconds: Optional[float] = None
collected_at: Optional[datetime.datetime] = None
class VoiceprintTemplate(BaseModel):
template_text: str
duration_seconds: int
sample_rate: int
channels: int

View File

@ -1,3 +1,7 @@
"""
异步知识库服务 - 处理知识库生成的异步任务
采用FastAPI BackgroundTasks模式
"""
import uuid import uuid
from datetime import datetime from datetime import datetime
from typing import Optional, Dict, Any, List from typing import Optional, Dict, Any, List
@ -7,6 +11,7 @@ from app.core.database import get_db_connection
from app.services.llm_service import LLMService from app.services.llm_service import LLMService
class AsyncKnowledgeBaseService: class AsyncKnowledgeBaseService:
"""异步知识库服务类 - 处理知识库相关的异步任务"""
def __init__(self): def __init__(self):
from app.core.config import REDIS_CONFIG from app.core.config import REDIS_CONFIG
@ -16,12 +21,25 @@ class AsyncKnowledgeBaseService:
self.llm_service = LLMService() self.llm_service = LLMService()
def start_generation(self, user_id: int, kb_id: int, user_prompt: Optional[str], source_meeting_ids: Optional[str], cursor=None) -> str: def start_generation(self, user_id: int, kb_id: int, user_prompt: Optional[str], source_meeting_ids: Optional[str], cursor=None) -> str:
"""
创建异步知识库生成任务
Args:
user_id: 用户ID
kb_id: 知识库ID
user_prompt: 用户提示词
source_meeting_ids: 源会议ID列表
cursor: 数据库游标可选
Returns:
str: 任务ID
"""
task_id = str(uuid.uuid4()) task_id = str(uuid.uuid4())
# If a cursor is passed, use it directly to avoid creating a new transaction # If a cursor is passed, use it directly to avoid creating a new transaction
if cursor: if cursor:
query = """ query = """
INSERT INTO knowledge_base_tasks (task_id, user_id, kb_id, user_prompt, created_at) INSERT INTO knowledge_base_tasks (task_id, user_id, kb_id, user_prompt, created_at)
VALUES (%s, %s, %s, %s, NOW()) VALUES (%s, %s, %s, %s, NOW())
""" """
cursor.execute(query, (task_id, user_id, kb_id, user_prompt)) cursor.execute(query, (task_id, user_id, kb_id, user_prompt))
@ -42,13 +60,17 @@ class AsyncKnowledgeBaseService:
} }
self.redis_client.hset(f"kb_task:{task_id}", mapping=task_data) self.redis_client.hset(f"kb_task:{task_id}", mapping=task_data)
self.redis_client.expire(f"kb_task:{task_id}", 86400) self.redis_client.expire(f"kb_task:{task_id}", 86400)
print(f"Knowledge base generation task created: {task_id} for kb_id: {kb_id}") print(f"Knowledge base generation task created: {task_id} for kb_id: {kb_id}")
return task_id return task_id
def _process_task(self, task_id: str): def _process_task(self, task_id: str):
"""
处理单个异步任务的函数设计为由BackgroundTasks调用
"""
print(f"Background task started for knowledge base task: {task_id}") print(f"Background task started for knowledge base task: {task_id}")
try: try:
# 从Redis获取任务数据
task_data = self.redis_client.hgetall(f"kb_task:{task_id}") task_data = self.redis_client.hgetall(f"kb_task:{task_id}")
if not task_data: if not task_data:
print(f"Error: Task {task_id} not found in Redis for processing.") print(f"Error: Task {task_id} not found in Redis for processing.")
@ -57,99 +79,136 @@ class AsyncKnowledgeBaseService:
kb_id = int(task_data['kb_id']) kb_id = int(task_data['kb_id'])
user_prompt = task_data.get('user_prompt', '') user_prompt = task_data.get('user_prompt', '')
# 1. 更新状态为processing
self._update_task_status_in_redis(task_id, 'processing', 10, message="任务已开始...") self._update_task_status_in_redis(task_id, 'processing', 10, message="任务已开始...")
with get_db_connection() as connection: # 2. 获取关联的会议总结
cursor = connection.cursor(dictionary=True) self._update_task_status_in_redis(task_id, 'processing', 20, message="获取关联会议纪要...")
source_text = self._get_meeting_summaries(kb_id)
# Get source meeting summaries # 3. 构建提示词
source_text = "" self._update_task_status_in_redis(task_id, 'processing', 30, message="准备AI提示词...")
cursor.execute("SELECT source_meeting_ids FROM knowledge_bases WHERE kb_id = %s", (kb_id,)) full_prompt = self._build_prompt(source_text, user_prompt)
kb_info = cursor.fetchone()
if kb_info and kb_info['source_meeting_ids']:
self._update_task_status_in_redis(task_id, 'processing', 20, message="获取关联会议纪要...")
meeting_ids = [int(m_id) for m_id in kb_info['source_meeting_ids'].split(',') if m_id.isdigit()]
if meeting_ids:
summaries = []
for meeting_id in meeting_ids:
cursor.execute("SELECT summary FROM meetings WHERE meeting_id = %s", (meeting_id,))
summary = cursor.fetchone()
if summary and summary['summary']:
summaries.append(summary['summary'])
source_text = "\n\n---\n\n".join(summaries)
# Get system prompt # 4. 调用LLM API
self._update_task_status_in_redis(task_id, 'processing', 30, message="获取知识库生成模版...") self._update_task_status_in_redis(task_id, 'processing', 50, message="AI正在生成知识库...")
system_prompt = self._get_knowledge_task_prompt(cursor) generated_content = self.llm_service._call_llm_api(full_prompt)
if not generated_content:
# Build final prompt raise Exception("LLM API调用失败或返回空内容")
final_prompt = f"{system_prompt}\n\n"
if source_text:
final_prompt += f"请参考以下会议纪要内容:\n{source_text}\n\n"
final_prompt += f"用户要求:{user_prompt}"
self._update_task_status_in_redis(task_id, 'processing', 50, message="AI正在生成知识库...") # 5. 保存结果到数据库
generated_content = self.llm_service._call_llm_api(final_prompt) self._update_task_status_in_redis(task_id, 'processing', 95, message="保存结果...")
self._save_result_to_db(kb_id, generated_content)
if not generated_content: # 6. 任务完成
raise Exception("LLM API call failed or returned empty content") self._update_task_in_db(task_id, 'completed', 100)
self._update_task_status_in_redis(task_id, 'completed', 100)
self._update_task_status_in_redis(task_id, 'processing', 95, message="保存结果...") print(f"Task {task_id} completed successfully")
self._save_result_to_db(kb_id, generated_content, cursor)
self._update_task_in_db(task_id, 'completed', 100, cursor=cursor)
self._update_task_status_in_redis(task_id, 'completed', 100)
connection.commit()
print(f"Task {task_id} completed successfully")
except Exception as e: except Exception as e:
error_msg = str(e) error_msg = str(e)
print(f"Task {task_id} failed: {error_msg}") print(f"Task {task_id} failed: {error_msg}")
# Use a new connection for error logging to avoid issues with a potentially broken transaction # 更新失败状态
with get_db_connection() as err_conn: self._update_task_in_db(task_id, 'failed', 0, error_message=error_msg)
err_cursor = err_conn.cursor()
self._update_task_in_db(task_id, 'failed', 0, error_message=error_msg, cursor=err_cursor)
err_conn.commit()
self._update_task_status_in_redis(task_id, 'failed', 0, error_message=error_msg) self._update_task_status_in_redis(task_id, 'failed', 0, error_message=error_msg)
def _get_knowledge_task_prompt(self, cursor) -> str: # --- 知识库相关方法 ---
query = """
SELECT p.content def _get_meeting_summaries(self, kb_id: int) -> str:
FROM prompt_config pc
JOIN prompts p ON pc.prompt_id = p.id
WHERE pc.task_name = 'KNOWLEDGE_TASK'
""" """
cursor.execute(query) 从数据库获取知识库关联的会议总结
result = cursor.fetchone()
if result:
return result['content']
else:
# Fallback prompt
return "Please generate a knowledge base article based on the provided information."
Args:
kb_id: 知识库ID
Returns:
str: 拼接后的会议总结文本
"""
try:
with get_db_connection() as connection:
cursor = connection.cursor(dictionary=True)
def _save_result_to_db(self, kb_id: int, content: str, cursor): # 获取知识库的源会议ID列表
query = "UPDATE knowledge_bases SET content = %s, updated_at = NOW() WHERE kb_id = %s" cursor.execute("SELECT source_meeting_ids FROM knowledge_bases WHERE kb_id = %s", (kb_id,))
cursor.execute(query, (content, kb_id)) kb_info = cursor.fetchone()
def _update_task_in_db(self, task_id: str, status: str, progress: int, error_message: str = None, cursor=None): if not kb_info or not kb_info['source_meeting_ids']:
query = "UPDATE knowledge_base_tasks SET status = %s, progress = %s, error_message = %s, updated_at = NOW(), completed_at = IF(%s = 'completed', NOW(), completed_at) WHERE task_id = %s" return ""
cursor.execute(query, (status, progress, error_message, status, task_id))
def _update_task_status_in_redis(self, task_id: str, status: str, progress: int, message: str = None, error_message: str = None): # 解析会议ID列表
update_data = { meeting_ids = [int(m_id) for m_id in kb_info['source_meeting_ids'].split(',') if m_id.isdigit()]
'status': status, if not meeting_ids:
'progress': str(progress), return ""
'updated_at': datetime.now().isoformat()
} # 获取所有会议的总结
if message: update_data['message'] = message summaries = []
if error_message: update_data['error_message'] = error_message for meeting_id in meeting_ids:
self.redis_client.hset(f"kb_task:{task_id}", mapping=update_data) cursor.execute("SELECT summary FROM meetings WHERE meeting_id = %s", (meeting_id,))
summary = cursor.fetchone()
if summary and summary['summary']:
summaries.append(summary['summary'])
# 用分隔符拼接多个会议总结
return "\n\n---\n\n".join(summaries)
except Exception as e:
print(f"获取会议总结错误: {e}")
return ""
def _build_prompt(self, source_text: str, user_prompt: str) -> str:
"""
构建完整的提示词
使用数据库中配置的KNOWLEDGE_TASK提示词模板
Args:
source_text: 源会议总结文本
user_prompt: 用户自定义提示词
Returns:
str: 完整的提示词
"""
# 从数据库获取知识库任务的提示词模板
system_prompt = self.llm_service.get_task_prompt('KNOWLEDGE_TASK')
prompt = f"{system_prompt}\n\n"
if source_text:
prompt += f"请参考以下会议纪要内容:\n{source_text}\n\n"
prompt += f"用户要求:{user_prompt}"
return prompt
def _save_result_to_db(self, kb_id: int, content: str) -> Optional[int]:
"""
保存生成结果到数据库
Args:
kb_id: 知识库ID
content: 生成的内容
Returns:
Optional[int]: 知识库ID失败返回None
"""
try:
with get_db_connection() as connection:
cursor = connection.cursor()
query = "UPDATE knowledge_bases SET content = %s, updated_at = NOW() WHERE kb_id = %s"
cursor.execute(query, (content, kb_id))
connection.commit()
print(f"成功保存知识库内容kb_id: {kb_id}")
return kb_id
except Exception as e:
print(f"保存知识库内容错误: {e}")
return None
# --- 状态查询和数据库操作方法 ---
def get_task_status(self, task_id: str) -> Dict[str, Any]: def get_task_status(self, task_id: str) -> Dict[str, Any]:
"""获取任务状态 - 与 async_llm_service 保持一致""" """获取任务状态"""
try: try:
task_data = self.redis_client.hgetall(f"kb_task:{task_id}") task_data = self.redis_client.hgetall(f"kb_task:{task_id}")
if not task_data: if not task_data:
@ -170,6 +229,20 @@ class AsyncKnowledgeBaseService:
print(f"Error getting task status: {e}") print(f"Error getting task status: {e}")
return {'task_id': task_id, 'status': 'error', 'error_message': str(e)} return {'task_id': task_id, 'status': 'error', 'error_message': str(e)}
def _update_task_status_in_redis(self, task_id: str, status: str, progress: int, message: str = None, error_message: str = None):
"""更新Redis中的任务状态"""
try:
update_data = {
'status': status,
'progress': str(progress),
'updated_at': datetime.now().isoformat()
}
if message: update_data['message'] = message
if error_message: update_data['error_message'] = error_message
self.redis_client.hset(f"kb_task:{task_id}", mapping=update_data)
except Exception as e:
print(f"Error updating task status in Redis: {e}")
def _save_task_to_db(self, task_id: str, user_id: int, kb_id: int, user_prompt: str): def _save_task_to_db(self, task_id: str, user_id: int, kb_id: int, user_prompt: str):
"""保存任务到数据库""" """保存任务到数据库"""
try: try:
@ -182,6 +255,17 @@ class AsyncKnowledgeBaseService:
print(f"Error saving task to database: {e}") print(f"Error saving task to database: {e}")
raise raise
def _update_task_in_db(self, task_id: str, status: str, progress: int, error_message: str = None):
"""更新数据库中的任务状态"""
try:
with get_db_connection() as connection:
cursor = connection.cursor()
query = "UPDATE knowledge_base_tasks SET status = %s, progress = %s, error_message = %s, updated_at = NOW(), completed_at = IF(%s = 'completed', NOW(), completed_at) WHERE task_id = %s"
cursor.execute(query, (status, progress, error_message, status, task_id))
connection.commit()
except Exception as e:
print(f"Error updating task in database: {e}")
def _get_task_from_db(self, task_id: str) -> Optional[Dict[str, str]]: def _get_task_from_db(self, task_id: str) -> Optional[Dict[str, str]]:
"""从数据库获取任务信息""" """从数据库获取任务信息"""
try: try:
@ -198,4 +282,5 @@ class AsyncKnowledgeBaseService:
print(f"Error getting task from database: {e}") print(f"Error getting task from database: {e}")
return None return None
# 创建全局实例
async_kb_service = AsyncKnowledgeBaseService() async_kb_service = AsyncKnowledgeBaseService()

View File

@ -1,5 +1,5 @@
""" """
异步LLM服务 - 处理会议总结生成的异步任务 异步会议服务 - 处理会议总结生成的异步任务
采用FastAPI BackgroundTasks模式 采用FastAPI BackgroundTasks模式
""" """
import uuid import uuid
@ -12,30 +12,30 @@ from app.core.config import REDIS_CONFIG
from app.core.database import get_db_connection from app.core.database import get_db_connection
from app.services.llm_service import LLMService from app.services.llm_service import LLMService
class AsyncLLMService: class AsyncMeetingService:
"""异步LLM服务类 - 采用FastAPI BackgroundTasks模式""" """异步会议服务类 - 处理会议相关的异步任务"""
def __init__(self): def __init__(self):
# 确保redis客户端自动解码响应代码更简洁 # 确保redis客户端自动解码响应代码更简洁
if 'decode_responses' not in REDIS_CONFIG: if 'decode_responses' not in REDIS_CONFIG:
REDIS_CONFIG['decode_responses'] = True REDIS_CONFIG['decode_responses'] = True
self.redis_client = redis.Redis(**REDIS_CONFIG) self.redis_client = redis.Redis(**REDIS_CONFIG)
self.llm_service = LLMService() # 复用现有的同步LLM服务 self.llm_service = LLMService() # 复用现有的同步LLM服务
def start_summary_generation(self, meeting_id: int, user_prompt: str = "") -> str: def start_summary_generation(self, meeting_id: int, user_prompt: str = "") -> str:
""" """
创建异步总结任务任务的执行将由外部如API层的BackgroundTasks触发 创建异步总结任务任务的执行将由外部如API层的BackgroundTasks触发
Args: Args:
meeting_id: 会议ID meeting_id: 会议ID
user_prompt: 用户额外提示词 user_prompt: 用户额外提示词
Returns: Returns:
str: 任务ID str: 任务ID
""" """
try: try:
task_id = str(uuid.uuid4()) task_id = str(uuid.uuid4())
# 在数据库中创建任务记录 # 在数据库中创建任务记录
self._save_task_to_db(task_id, meeting_id, user_prompt) self._save_task_to_db(task_id, meeting_id, user_prompt)
@ -52,10 +52,10 @@ class AsyncLLMService:
} }
self.redis_client.hset(f"llm_task:{task_id}", mapping=task_data) self.redis_client.hset(f"llm_task:{task_id}", mapping=task_data)
self.redis_client.expire(f"llm_task:{task_id}", 86400) self.redis_client.expire(f"llm_task:{task_id}", 86400)
print(f"LLM summary task created: {task_id} for meeting: {meeting_id}") print(f"Meeting summary task created: {task_id} for meeting: {meeting_id}")
return task_id return task_id
except Exception as e: except Exception as e:
print(f"Error starting summary generation: {e}") print(f"Error starting summary generation: {e}")
raise e raise e
@ -64,7 +64,7 @@ class AsyncLLMService:
""" """
处理单个异步任务的函数设计为由BackgroundTasks调用 处理单个异步任务的函数设计为由BackgroundTasks调用
""" """
print(f"Background task started for LLM task: {task_id}") print(f"Background task started for meeting summary task: {task_id}")
try: try:
# 从Redis获取任务数据 # 从Redis获取任务数据
task_data = self.redis_client.hgetall(f"llm_task:{task_id}") task_data = self.redis_client.hgetall(f"llm_task:{task_id}")
@ -80,13 +80,13 @@ class AsyncLLMService:
# 2. 获取会议转录内容 # 2. 获取会议转录内容
self._update_task_status_in_redis(task_id, 'processing', 30, message="获取会议转录内容...") self._update_task_status_in_redis(task_id, 'processing', 30, message="获取会议转录内容...")
transcript_text = self.llm_service._get_meeting_transcript(meeting_id) transcript_text = self._get_meeting_transcript(meeting_id)
if not transcript_text: if not transcript_text:
raise Exception("无法获取会议转录内容") raise Exception("无法获取会议转录内容")
# 3. 构建提示词 # 3. 构建提示词
self._update_task_status_in_redis(task_id, 'processing', 40, message="准备AI提示词...") self._update_task_status_in_redis(task_id, 'processing', 40, message="准备AI提示词...")
full_prompt = self.llm_service._build_prompt(transcript_text, user_prompt) full_prompt = self._build_prompt(transcript_text, user_prompt)
# 4. 调用LLM API # 4. 调用LLM API
self._update_task_status_in_redis(task_id, 'processing', 50, message="AI正在分析会议内容...") self._update_task_status_in_redis(task_id, 'processing', 50, message="AI正在分析会议内容...")
@ -96,7 +96,7 @@ class AsyncLLMService:
# 5. 保存结果到主表 # 5. 保存结果到主表
self._update_task_status_in_redis(task_id, 'processing', 95, message="保存总结结果...") self._update_task_status_in_redis(task_id, 'processing', 95, message="保存总结结果...")
self.llm_service._save_summary_to_db(meeting_id, summary_content, user_prompt) self._save_summary_to_db(meeting_id, summary_content, user_prompt)
# 6. 任务完成 # 6. 任务完成
self._update_task_in_db(task_id, 'completed', 100, result=summary_content) self._update_task_in_db(task_id, 'completed', 100, result=summary_content)
@ -110,6 +110,78 @@ class AsyncLLMService:
self._update_task_in_db(task_id, 'failed', 0, error_message=error_msg) self._update_task_in_db(task_id, 'failed', 0, error_message=error_msg)
self._update_task_status_in_redis(task_id, 'failed', 0, error_message=error_msg) self._update_task_status_in_redis(task_id, 'failed', 0, error_message=error_msg)
# --- 会议相关方法 ---
def _get_meeting_transcript(self, meeting_id: int) -> str:
"""从数据库获取会议转录内容"""
try:
with get_db_connection() as connection:
cursor = connection.cursor()
query = """
SELECT speaker_tag, start_time_ms, end_time_ms, text_content
FROM transcript_segments
WHERE meeting_id = %s
ORDER BY start_time_ms
"""
cursor.execute(query, (meeting_id,))
segments = cursor.fetchall()
if not segments:
return ""
# 组装转录文本
transcript_lines = []
for speaker_tag, start_time, end_time, text in segments:
# 将毫秒转换为分:秒格式
start_min = start_time // 60000
start_sec = (start_time % 60000) // 1000
transcript_lines.append(f"[{start_min:02d}:{start_sec:02d}] 说话人{speaker_tag}: {text}")
return "\n".join(transcript_lines)
except Exception as e:
print(f"获取会议转录内容错误: {e}")
return ""
def _build_prompt(self, transcript_text: str, user_prompt: str) -> str:
"""
构建完整的提示词
使用数据库中配置的MEETING_TASK提示词模板
"""
# 从数据库获取会议任务的提示词模板
system_prompt = self.llm_service.get_task_prompt('MEETING_TASK')
prompt = f"{system_prompt}\n\n"
if user_prompt:
prompt += f"用户额外要求:{user_prompt}\n\n"
prompt += f"会议转录内容:\n{transcript_text}\n\n请根据以上内容生成会议总结:"
return prompt
def _save_summary_to_db(self, meeting_id: int, summary_content: str, user_prompt: str) -> Optional[int]:
"""保存总结到数据库 - 更新meetings表的summary、user_prompt和updated_at字段"""
try:
with get_db_connection() as connection:
cursor = connection.cursor()
# 更新meetings表的summary、user_prompt和updated_at字段
update_query = """
UPDATE meetings
SET summary = %s, user_prompt = %s, updated_at = NOW()
WHERE meeting_id = %s
"""
cursor.execute(update_query, (summary_content, user_prompt, meeting_id))
connection.commit()
print(f"成功保存会议总结到meetings表meeting_id: {meeting_id}")
return meeting_id
except Exception as e:
print(f"保存总结到数据库错误: {e}")
return None
# --- 状态查询和数据库操作方法 --- # --- 状态查询和数据库操作方法 ---
def get_task_status(self, task_id: str) -> Dict[str, Any]: def get_task_status(self, task_id: str) -> Dict[str, Any]:
@ -120,7 +192,7 @@ class AsyncLLMService:
task_data = self._get_task_from_db(task_id) task_data = self._get_task_from_db(task_id)
if not task_data: if not task_data:
return {'task_id': task_id, 'status': 'not_found', 'error_message': 'Task not found'} return {'task_id': task_id, 'status': 'not_found', 'error_message': 'Task not found'}
return { return {
'task_id': task_id, 'task_id': task_id,
'status': task_data.get('status', 'unknown'), 'status': task_data.get('status', 'unknown'),
@ -189,7 +261,7 @@ class AsyncLLMService:
params.insert(2, result) params.insert(2, result)
else: else:
query = "UPDATE llm_tasks SET status = %s, progress = %s, error_message = %s WHERE task_id = %s" query = "UPDATE llm_tasks SET status = %s, progress = %s, error_message = %s WHERE task_id = %s"
cursor.execute(query, tuple(params)) cursor.execute(query, tuple(params))
connection.commit() connection.commit()
except Exception as e: except Exception as e:
@ -212,4 +284,4 @@ class AsyncLLMService:
return None return None
# 创建全局实例 # 创建全局实例
async_llm_service = AsyncLLMService() async_meeting_service = AsyncMeetingService()

View File

@ -7,6 +7,8 @@ from app.core.database import get_db_connection
class LLMService: class LLMService:
"""LLM服务 - 专注于大模型API调用和提示词管理"""
def __init__(self): def __init__(self):
# 设置dashscope API key # 设置dashscope API key
dashscope.api_key = config_module.QWEN_API_KEY dashscope.api_key = config_module.QWEN_API_KEY
@ -35,125 +37,49 @@ class LLMService:
def top_p(self): def top_p(self):
"""动态获取top_p""" """动态获取top_p"""
return config_module.LLM_CONFIG["top_p"] return config_module.LLM_CONFIG["top_p"]
def generate_meeting_summary_stream(self, meeting_id: int, user_prompt: str = "") -> Generator[str, None, None]: def get_task_prompt(self, task_name: str, cursor=None) -> str:
""" """
流式生成会议总结 统一的提示词获取方法
Args: Args:
meeting_id: 会议ID task_name: 任务名称 'MEETING_TASK', 'KNOWLEDGE_TASK'
user_prompt: 用户额外提示词 cursor: 数据库游标如果传入则使用否则创建新连接
Yields:
str: 流式输出的内容片段
"""
try:
# 获取会议转录内容
transcript_text = self._get_meeting_transcript(meeting_id)
if not transcript_text:
yield "error: 无法获取会议转录内容"
return
# 构建完整提示词
full_prompt = self._build_prompt(transcript_text, user_prompt)
# 调用大模型API进行流式生成
full_content = ""
for chunk in self._call_llm_api_stream(full_prompt):
if chunk.startswith("error:"):
yield chunk
return
full_content += chunk
yield chunk
# 保存完整总结到数据库
if full_content:
self._save_summary_to_db(meeting_id, full_content, user_prompt)
except Exception as e:
print(f"流式生成会议总结错误: {e}")
yield f"error: {str(e)}"
def generate_meeting_summary(self, meeting_id: int, user_prompt: str = "") -> Optional[Dict]:
"""
生成会议总结非流式保持向后兼容
Args:
meeting_id: 会议ID
user_prompt: 用户额外提示词
Returns: Returns:
包含总结内容的字典如果失败返回None str: 提示词内容如果未找到返回默认提示词
"""
query = """
SELECT p.content
FROM prompt_config pc
JOIN prompts p ON pc.prompt_id = p.id
WHERE pc.task_name = %s
""" """
try:
# 获取会议转录内容
transcript_text = self._get_meeting_transcript(meeting_id)
if not transcript_text:
return {"error": "无法获取会议转录内容"}
# 构建完整提示词 if cursor:
full_prompt = self._build_prompt(transcript_text, user_prompt) cursor.execute(query, (task_name,))
result = cursor.fetchone()
# 调用大模型API if result:
response = self._call_llm_api(full_prompt) return result['content'] if isinstance(result, dict) else result[0]
else:
if response:
# 保存总结到数据库
summary_id = self._save_summary_to_db(meeting_id, response, user_prompt)
return {
"summary_id": summary_id,
"content": response,
"meeting_id": meeting_id
}
else:
return {"error": "大模型API调用失败"}
except Exception as e:
print(f"生成会议总结错误: {e}")
return {"error": str(e)}
def _get_meeting_transcript(self, meeting_id: int) -> str:
"""从数据库获取会议转录内容"""
try:
with get_db_connection() as connection: with get_db_connection() as connection:
cursor = connection.cursor() cursor = connection.cursor(dictionary=True)
query = """ cursor.execute(query, (task_name,))
SELECT speaker_tag, start_time_ms, end_time_ms, text_content result = cursor.fetchone()
FROM transcript_segments if result:
WHERE meeting_id = %s return result['content']
ORDER BY start_time_ms
""" # 返回默认提示词
cursor.execute(query, (meeting_id,)) return self._get_default_prompt(task_name)
segments = cursor.fetchall()
def _get_default_prompt(self, task_name: str) -> str:
if not segments: """获取默认提示词"""
return "" default_prompts = {
'MEETING_TASK': self.system_prompt, # 使用配置文件中的系统提示词
# 组装转录文本 'KNOWLEDGE_TASK': "请根据提供的信息生成知识库文章。",
transcript_lines = [] }
for speaker_tag, start_time, end_time, text in segments: return default_prompts.get(task_name, "请根据提供的内容进行总结和分析。")
# 将毫秒转换为分:秒格式
start_min = start_time // 60000
start_sec = (start_time % 60000) // 1000
transcript_lines.append(f"[{start_min:02d}:{start_sec:02d}] 说话人{speaker_tag}: {text}")
return "\n".join(transcript_lines)
except Exception as e:
print(f"获取会议转录内容错误: {e}")
return ""
def _build_prompt(self, transcript_text: str, user_prompt: str) -> str:
"""构建完整的提示词"""
prompt = f"{self.system_prompt}\n\n"
if user_prompt:
prompt += f"用户额外要求:{user_prompt}\n\n"
prompt += f"会议转录内容:\n{transcript_text}\n\n请根据以上内容生成会议总结:"
return prompt
def _call_llm_api_stream(self, prompt: str) -> Generator[str, None, None]: def _call_llm_api_stream(self, prompt: str) -> Generator[str, None, None]:
"""流式调用阿里Qwen3大模型API""" """流式调用阿里Qwen3大模型API"""
try: try:
@ -185,7 +111,7 @@ class LLMService:
yield f"error: {error_msg}" yield f"error: {error_msg}"
def _call_llm_api(self, prompt: str) -> Optional[str]: def _call_llm_api(self, prompt: str) -> Optional[str]:
"""调用阿里Qwen3大模型API非流式,保持向后兼容""" """调用阿里Qwen3大模型API非流式"""
try: try:
response = dashscope.Generation.call( response = dashscope.Generation.call(
model=self.model_name, model=self.model_name,
@ -204,96 +130,18 @@ class LLMService:
except Exception as e: except Exception as e:
print(f"调用大模型API错误: {e}") print(f"调用大模型API错误: {e}")
return None return None
def _save_summary_to_db(self, meeting_id: int, summary_content: str, user_prompt: str) -> Optional[int]:
"""保存总结到数据库 - 更新meetings表的summary字段"""
try:
with get_db_connection() as connection:
cursor = connection.cursor()
# 更新meetings表的summary字段
update_query = """
UPDATE meetings
SET summary = %s
WHERE meeting_id = %s
"""
cursor.execute(update_query, (summary_content, meeting_id))
connection.commit()
print(f"成功保存会议总结到meetings表meeting_id: {meeting_id}")
return meeting_id
except Exception as e:
print(f"保存总结到数据库错误: {e}")
return None
def get_meeting_summaries(self, meeting_id: int) -> List[Dict]:
"""获取会议的当前总结 - 从meetings表的summary字段获取"""
try:
with get_db_connection() as connection:
cursor = connection.cursor()
query = """
SELECT summary
FROM meetings
WHERE meeting_id = %s
"""
cursor.execute(query, (meeting_id,))
result = cursor.fetchone()
# 如果有总结内容返回一个包含当前总结的列表格式保持API一致性
if result and result[0]:
return [{
"id": meeting_id,
"content": result[0],
"user_prompt": "", # meetings表中没有user_prompt字段
"created_at": None # meetings表中没有单独的总结创建时间
}]
else:
return []
except Exception as e:
print(f"获取会议总结错误: {e}")
return []
def get_current_meeting_summary(self, meeting_id: int) -> Optional[str]:
"""获取会议当前的总结内容 - 从meetings表的summary字段获取"""
try:
with get_db_connection() as connection:
cursor = connection.cursor()
query = """
SELECT summary
FROM meetings
WHERE meeting_id = %s
"""
cursor.execute(query, (meeting_id,))
result = cursor.fetchone()
return result[0] if result and result[0] else None
except Exception as e:
print(f"获取会议当前总结错误: {e}")
return None
# 测试代码 # 测试代码
if __name__ == '__main__': if __name__ == '__main__':
# 测试LLM服务
test_meeting_id = 38
test_user_prompt = "请重点关注决策事项和待办任务"
print("--- 运行LLM服务测试 ---") print("--- 运行LLM服务测试 ---")
llm_service = LLMService() llm_service = LLMService()
# 生成总结 # 测试获取任务提示词
result = llm_service.generate_meeting_summary(test_meeting_id, test_user_prompt) meeting_prompt = llm_service.get_task_prompt('MEETING_TASK')
if result.get("error"): print(f"会议任务提示词: {meeting_prompt[:100]}...")
print(f"生成总结失败: {result['error']}")
else: knowledge_prompt = llm_service.get_task_prompt('KNOWLEDGE_TASK')
print(f"总结生成成功ID: {result.get('summary_id')}") print(f"知识库任务提示词: {knowledge_prompt[:100]}...")
print(f"总结内容: {result.get('content')[:200]}...")
print("--- LLM服务测试完成 ---")
# 获取历史总结
summaries = llm_service.get_meeting_summaries(test_meeting_id)
print(f"获取到 {len(summaries)} 个历史总结")
print("--- LLM服务测试完成 ---")

View File

@ -0,0 +1,218 @@
"""
声纹服务 - 处理用户声纹采集存储和验证
"""
import os
import json
import wave
from datetime import datetime
from typing import Optional, Dict
from pathlib import Path
from app.core.database import get_db_connection
import app.core.config as config_module
class VoiceprintService:
"""声纹服务类 - 同步处理声纹采集"""
def __init__(self):
self.voiceprint_dir = config_module.VOICEPRINT_DIR
self.voiceprint_config = config_module.VOICEPRINT_CONFIG
def get_user_voiceprint_status(self, user_id: int) -> Dict:
"""
获取用户声纹状态
Args:
user_id: 用户ID
Returns:
Dict: 声纹状态信息
"""
try:
with get_db_connection() as connection:
cursor = connection.cursor(dictionary=True)
query = """
SELECT vp_id, user_id, file_path, file_size, duration_seconds, collected_at, updated_at
FROM user_voiceprint
WHERE user_id = %s
"""
cursor.execute(query, (user_id,))
voiceprint = cursor.fetchone()
if voiceprint:
return {
"has_voiceprint": True,
"vp_id": voiceprint['vp_id'],
"file_path": voiceprint['file_path'],
"duration_seconds": float(voiceprint['duration_seconds']) if voiceprint['duration_seconds'] else None,
"collected_at": voiceprint['collected_at'].isoformat() if voiceprint['collected_at'] else None
}
else:
return {
"has_voiceprint": False,
"vp_id": None,
"file_path": None,
"duration_seconds": None,
"collected_at": None
}
except Exception as e:
print(f"获取声纹状态错误: {e}")
raise e
def save_voiceprint(self, user_id: int, audio_file_path: str, file_size: int) -> Dict:
"""
保存声纹文件并提取特征向量
Args:
user_id: 用户ID
audio_file_path: 音频文件路径
file_size: 文件大小
Returns:
Dict: 保存结果
"""
try:
# 1. 获取音频时长
duration = self._get_audio_duration(audio_file_path)
# 2. 提取声纹向量调用FunASR
vector_data = self._extract_voiceprint_vector(audio_file_path)
# 3. 保存到数据库
with get_db_connection() as connection:
cursor = connection.cursor(dictionary=True)
# 检查用户是否已有声纹
cursor.execute("SELECT vp_id FROM user_voiceprint WHERE user_id = %s", (user_id,))
existing = cursor.fetchone()
# 计算相对路径
relative_path = str(Path(audio_file_path).relative_to(config_module.BASE_DIR))
if existing:
# 更新现有记录
update_query = """
UPDATE user_voiceprint
SET file_path = %s, file_size = %s, duration_seconds = %s,
vector_data = %s, updated_at = NOW()
WHERE user_id = %s
"""
cursor.execute(update_query, (
relative_path, file_size, duration,
json.dumps(vector_data) if vector_data else None,
user_id
))
vp_id = existing['vp_id']
else:
# 插入新记录
insert_query = """
INSERT INTO user_voiceprint
(user_id, file_path, file_size, duration_seconds, vector_data, collected_at, updated_at)
VALUES (%s, %s, %s, %s, %s, NOW(), NOW())
"""
cursor.execute(insert_query, (
user_id, relative_path, file_size, duration,
json.dumps(vector_data) if vector_data else None
))
vp_id = cursor.lastrowid
connection.commit()
return {
"vp_id": vp_id,
"user_id": user_id,
"file_path": relative_path,
"file_size": file_size,
"duration_seconds": duration,
"has_vector": vector_data is not None
}
except Exception as e:
print(f"保存声纹错误: {e}")
raise e
def delete_voiceprint(self, user_id: int) -> bool:
"""
删除用户声纹
Args:
user_id: 用户ID
Returns:
bool: 是否删除成功
"""
try:
with get_db_connection() as connection:
cursor = connection.cursor(dictionary=True)
# 获取文件路径
cursor.execute("SELECT file_path FROM user_voiceprint WHERE user_id = %s", (user_id,))
voiceprint = cursor.fetchone()
if voiceprint:
# 构建完整文件路径
relative_path = voiceprint['file_path']
if relative_path.startswith('/'):
relative_path = relative_path.lstrip('/')
file_path = config_module.BASE_DIR / relative_path
# 删除数据库记录
cursor.execute("DELETE FROM user_voiceprint WHERE user_id = %s", (user_id,))
connection.commit()
# 删除文件
if file_path.exists():
os.remove(file_path)
return True
else:
return False
except Exception as e:
print(f"删除声纹错误: {e}")
raise e
def _get_audio_duration(self, audio_file_path: str) -> float:
"""
获取音频文件时长
Args:
audio_file_path: 音频文件路径
Returns:
float: 时长
"""
try:
with wave.open(audio_file_path, 'rb') as wav_file:
frames = wav_file.getnframes()
rate = wav_file.getframerate()
duration = frames / float(rate)
return round(duration, 2)
except Exception as e:
print(f"获取音频时长错误: {e}")
return 10.0 # 默认返回10秒
def _extract_voiceprint_vector(self, audio_file_path: str) -> Optional[list]:
"""
提取声纹特征向量调用FunASR
Args:
audio_file_path: 音频文件路径
Returns:
Optional[list]: 声纹向量192失败返回None
"""
# TODO: 集成FunASR的说话人识别模型
# 使用 speech_campplus_sv_zh-cn_16k-common 模型
# 返回192维的embedding向量
print(f"[TODO] 调用FunASR提取声纹向量: {audio_file_path}")
# 暂时返回None等待FunASR集成
# 集成后应该返回类似: [0.123, -0.456, 0.789, ...]
return None
# 创建全局实例
voiceprint_service = VoiceprintService()

View File

@ -1,7 +1,7 @@
{ {
"model_name": "qwen-plus", "model_name": "qwen-plus",
"system_prompt": "你是一个专业的会议记录分析助手。请根据提供的会议转录内容,生成简洁明了的会议总结。\n\n总结包括五个部分名称严格一致生成为MD二级目录\n1. 会议概述 - 简要说明会议的主要目的和背景(生成MD引用)\n2. 主要讨论点 - 列出会议中讨论的重要话题和内容\n3. 决策事项 - 明确记录会议中做出的决定和结论\n4. 待办事项 - 列出需要后续跟进的任务和责任人\n5. 关键信息 - 其他重要的信息点\n\n输出要求\n- 保持客观中性,不添加个人观点\n- 使用简洁、准确的中文表达\n- 按重要性排序各项内容\n- 如果某个部分没有相关内容,可以说明\"无相关内容\"\n- 总字数控制在500字以内", "system_prompt": "你是一个专业的会议记录分析助手。请根据提供的会议转录内容,生成简洁明了的会议总结。\n\n总结包括五个部分名称严格一致生成为MD二级目录\n1. 会议概述 - 简要说明会议的主要目的和背景(生成MD引用)\n2. 主要讨论点 - 列出会议中讨论的重要话题和内容\n3. 决策事项 - 明确记录会议中做出的决定和结论\n4. 待办事项 - 列出需要后续跟进的任务和责任人\n5. 关键信息 - 其他重要的信息点\n\n输出要求\n- 保持客观中性,不添加个人观点\n- 使用简洁、准确的中文表达\n- 按重要性排序各项内容\n- 如果某个部分没有相关内容,可以说明\"无相关内容\"\n- 总字数控制在500字以内",
"DEFAULT_RESET_PASSWORD": "111111", "DEFAULT_RESET_PASSWORD": "123456",
"MAX_FILE_SIZE": 209715200, "MAX_FILE_SIZE": 208666624,
"MAX_IMAGE_SIZE": 10485760 "MAX_IMAGE_SIZE": 10485760
} }

View File

@ -2,7 +2,7 @@ import uvicorn
from fastapi import FastAPI, Request, HTTPException from fastapi import FastAPI, Request, HTTPException
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from app.api.endpoints import auth, users, meetings, tags, admin, tasks, prompts, knowledge_base, client_downloads from app.api.endpoints import auth, users, meetings, tags, admin, tasks, prompts, knowledge_base, client_downloads, voiceprint
from app.core.config import UPLOAD_DIR, API_CONFIG from app.core.config import UPLOAD_DIR, API_CONFIG
from app.api.endpoints.admin import load_system_config from app.api.endpoints.admin import load_system_config
import os import os
@ -39,6 +39,7 @@ app.include_router(tasks.router, prefix="/api", tags=["Tasks"])
app.include_router(prompts.router, prefix="/api", tags=["Prompts"]) app.include_router(prompts.router, prefix="/api", tags=["Prompts"])
app.include_router(knowledge_base.router, prefix="/api", tags=["KnowledgeBase"]) app.include_router(knowledge_base.router, prefix="/api", tags=["KnowledgeBase"])
app.include_router(client_downloads.router, prefix="/api/clients", tags=["ClientDownloads"]) app.include_router(client_downloads.router, prefix="/api/clients", tags=["ClientDownloads"])
app.include_router(voiceprint.router, prefix="/api", tags=["Voiceprint"])
@app.get("/") @app.get("/")
def read_root(): def read_root():

View File

@ -6,10 +6,10 @@ import sys
import os import os
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from app.services.async_llm_service import AsyncLLMService from app.services.async_meeting_service import AsyncMeetingService
# 创建服务实例 # 创建服务实例
service = AsyncLLMService() service = AsyncMeetingService()
# 创建测试任务 # 创建测试任务
meeting_id = 38 meeting_id = 38

View File

@ -8,10 +8,10 @@ import time
import threading import threading
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from app.services.async_llm_service import AsyncLLMService from app.services.async_meeting_service import AsyncMeetingService
# 创建服务实例 # 创建服务实例
service = AsyncLLMService() service = AsyncMeetingService()
# 直接调用处理任务方法测试 # 直接调用处理任务方法测试
print("测试直接调用_process_tasks方法...") print("测试直接调用_process_tasks方法...")

View File

@ -0,0 +1,141 @@
"""
声纹采集API测试脚本
使用方法
1. 确保后端服务正在运行
2. 修改 USER_ID TOKEN 为实际值
3. 准备一个10秒的WAV音频文件
4. 运行: python test_voiceprint_api.py
"""
import requests
import json
# 配置
BASE_URL = "http://localhost:8000/api"
USER_ID = 1 # 修改为实际用户ID
TOKEN = "" # 登录后获取的token
# 请求头
headers = {
"Authorization": f"Bearer {TOKEN}",
"Content-Type": "application/json"
}
def test_get_template():
"""测试获取朗读模板"""
print("\n=== 测试1: 获取朗读模板 ===")
url = f"{BASE_URL}/voiceprint/template"
response = requests.get(url, headers=headers)
print(f"状态码: {response.status_code}")
print(f"响应: {json.dumps(response.json(), ensure_ascii=False, indent=2)}")
return response.json()
def test_get_status(user_id):
"""测试获取声纹状态"""
print(f"\n=== 测试2: 获取用户 {user_id} 的声纹状态 ===")
url = f"{BASE_URL}/voiceprint/{user_id}"
response = requests.get(url, headers=headers)
print(f"状态码: {response.status_code}")
print(f"响应: {json.dumps(response.json(), ensure_ascii=False, indent=2)}")
return response.json()
def test_upload_voiceprint(user_id, audio_file_path):
"""测试上传声纹"""
print(f"\n=== 测试3: 上传声纹音频 ===")
url = f"{BASE_URL}/voiceprint/{user_id}"
# 移除Content-Type让requests自动设置multipart/form-data
upload_headers = {
"Authorization": f"Bearer {TOKEN}"
}
with open(audio_file_path, 'rb') as f:
files = {'audio_file': (audio_file_path.split('/')[-1], f, 'audio/wav')}
response = requests.post(url, headers=upload_headers, files=files)
print(f"状态码: {response.status_code}")
print(f"响应: {json.dumps(response.json(), ensure_ascii=False, indent=2)}")
return response.json()
def test_delete_voiceprint(user_id):
"""测试删除声纹"""
print(f"\n=== 测试4: 删除用户 {user_id} 的声纹 ===")
url = f"{BASE_URL}/voiceprint/{user_id}"
response = requests.delete(url, headers=headers)
print(f"状态码: {response.status_code}")
print(f"响应: {json.dumps(response.json(), ensure_ascii=False, indent=2)}")
return response.json()
def login(username, password):
"""登录获取token"""
print("\n=== 登录获取Token ===")
url = f"{BASE_URL}/auth/login"
data = {
"username": username,
"password": password
}
response = requests.post(url, json=data)
if response.status_code == 200:
result = response.json()
if result.get('code') == '200':
token = result['data']['token']
print(f"登录成功Token: {token[:20]}...")
return token
else:
print(f"登录失败: {result.get('message')}")
return None
else:
print(f"请求失败,状态码: {response.status_code}")
return None
if __name__ == "__main__":
print("=" * 60)
print("声纹采集API测试脚本")
print("=" * 60)
# 步骤1: 登录如果没有token
if not TOKEN:
print("\n请先登录获取Token...")
username = input("用户名: ")
password = input("密码: ")
TOKEN = login(username, password)
if TOKEN:
headers["Authorization"] = f"Bearer {TOKEN}"
else:
print("登录失败,退出测试")
exit(1)
# 步骤2: 测试获取朗读模板
test_get_template()
# 步骤3: 测试获取声纹状态
test_get_status(USER_ID)
# 步骤4: 测试上传声纹(需要准备音频文件)
audio_file = input("\n请输入WAV音频文件路径 (回车跳过上传测试): ")
if audio_file.strip():
test_upload_voiceprint(USER_ID, audio_file.strip())
# 上传后再次查看状态
print("\n=== 上传后再次查看状态 ===")
test_get_status(USER_ID)
# 步骤5: 测试删除声纹
confirm = input("\n是否测试删除声纹? (yes/no): ")
if confirm.lower() == 'yes':
test_delete_voiceprint(USER_ID)
# 删除后再次查看状态
print("\n=== 删除后再次查看状态 ===")
test_get_status(USER_ID)
print("\n" + "=" * 60)
print("测试完成")
print("=" * 60)