Compare commits
2 Commits
5bdab4a405
...
2f36474f4d
| Author | SHA1 | Date |
|---|---|---|
|
|
2f36474f4d | |
|
|
976ea854b6 |
|
|
@ -9,29 +9,25 @@ import datetime
|
|||
|
||||
router = APIRouter()
|
||||
|
||||
def _process_tags(cursor, tag_string: Optional[str]) -> List[Tag]:
|
||||
def _process_tags(cursor, tag_string: Optional[str], creator_id: Optional[int] = None) -> List[Tag]:
|
||||
"""
|
||||
处理标签:查询已存在的标签,如果提供了 creator_id 则创建不存在的标签
|
||||
"""
|
||||
if not tag_string:
|
||||
return []
|
||||
tag_names = [name.strip() for name in tag_string.split(',') if name.strip()]
|
||||
if not tag_names:
|
||||
return []
|
||||
|
||||
placeholders = ','.join(['%s'] * len(tag_names))
|
||||
select_query = f"SELECT id, name, color FROM tags WHERE name IN ({placeholders})"
|
||||
cursor.execute(select_query, tuple(tag_names))
|
||||
# 如果提供了 creator_id,则创建不存在的标签
|
||||
if creator_id:
|
||||
insert_ignore_query = "INSERT IGNORE INTO tags (name, creator_id) VALUES (%s, %s)"
|
||||
cursor.executemany(insert_ignore_query, [(name, creator_id) for name in tag_names])
|
||||
|
||||
# 查询所有标签信息
|
||||
format_strings = ', '.join(['%s'] * len(tag_names))
|
||||
cursor.execute(f"SELECT id, name, color FROM tags WHERE name IN ({format_strings})", tuple(tag_names))
|
||||
tags_data = cursor.fetchall()
|
||||
|
||||
existing_tags = {tag['name']: tag for tag in tags_data}
|
||||
new_tags = [name for name in tag_names if name not in existing_tags]
|
||||
|
||||
if new_tags:
|
||||
insert_query = "INSERT INTO tags (name) VALUES (%s)"
|
||||
cursor.executemany(insert_query, [(name,) for name in new_tags])
|
||||
|
||||
# Re-fetch all tags to get their IDs and default colors
|
||||
cursor.execute(select_query, tuple(tag_names))
|
||||
tags_data = cursor.fetchall()
|
||||
|
||||
return [Tag(**tag) for tag in tags_data]
|
||||
|
||||
|
||||
|
|
@ -83,7 +79,8 @@ def get_knowledge_bases(
|
|||
|
||||
kb_list = []
|
||||
for kb_data in kbs_data:
|
||||
kb_data['tags'] = _process_tags(cursor, kb_data.get('tags'))
|
||||
# 列表页不需要处理 tags,直接使用字符串
|
||||
# kb_data['tags'] 保持原样(逗号分隔的标签名称字符串)
|
||||
# Count source meetings - filter empty strings
|
||||
if kb_data.get('source_meeting_ids'):
|
||||
meeting_ids = [mid.strip() for mid in kb_data['source_meeting_ids'].split(',') if mid.strip()]
|
||||
|
|
@ -122,7 +119,7 @@ def create_knowledge_base(
|
|||
request.is_shared,
|
||||
request.source_meeting_ids,
|
||||
request.user_prompt,
|
||||
request.tags,
|
||||
request.tags, # 创建时 tags 应该为 None 或空字符串
|
||||
now,
|
||||
now
|
||||
))
|
||||
|
|
@ -171,7 +168,8 @@ def get_knowledge_base_detail(
|
|||
if not kb_data['is_shared'] and kb_data['creator_id'] != current_user['user_id']:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
# Process tags
|
||||
# Process tags - 获取标签的完整信息(包括颜色)
|
||||
# 详情页不需要创建新标签,所以不传 creator_id
|
||||
kb_data['tags'] = _process_tags(cursor, kb_data.get('tags'))
|
||||
|
||||
# Get source meetings details
|
||||
|
|
@ -220,6 +218,10 @@ def update_knowledge_base(
|
|||
if kb['creator_id'] != current_user['user_id']:
|
||||
raise HTTPException(status_code=403, detail="Only the creator can update this knowledge base")
|
||||
|
||||
# 使用 _process_tags 处理标签(会自动创建新标签)
|
||||
if request.tags:
|
||||
_process_tags(cursor, request.tags, current_user['user_id'])
|
||||
|
||||
# Update the knowledge base
|
||||
now = datetime.datetime.utcnow()
|
||||
update_query = """
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from app.core.config import BASE_DIR, AUDIO_DIR, MARKDOWN_DIR, ALLOWED_EXTENSION
|
|||
import app.core.config as config_module
|
||||
from app.services.llm_service import LLMService
|
||||
from app.services.async_transcription_service import AsyncTranscriptionService
|
||||
from app.services.async_llm_service import async_llm_service
|
||||
from app.services.async_meeting_service import async_meeting_service
|
||||
from app.core.auth import get_current_user
|
||||
from app.core.response import create_api_response
|
||||
from typing import List, Optional
|
||||
|
|
@ -23,14 +23,22 @@ transcription_service = AsyncTranscriptionService()
|
|||
class GenerateSummaryRequest(BaseModel):
|
||||
user_prompt: Optional[str] = ""
|
||||
|
||||
def _process_tags(cursor, tag_string: Optional[str]) -> List[Tag]:
|
||||
def _process_tags(cursor, tag_string: Optional[str], creator_id: Optional[int] = None) -> List[Tag]:
|
||||
"""
|
||||
处理标签:查询已存在的标签,如果提供了 creator_id 则创建不存在的标签
|
||||
"""
|
||||
if not tag_string:
|
||||
return []
|
||||
tag_names = [name.strip() for name in tag_string.split(',') if name.strip()]
|
||||
if not tag_names:
|
||||
return []
|
||||
insert_ignore_query = "INSERT IGNORE INTO tags (name) VALUES (%s)"
|
||||
cursor.executemany(insert_ignore_query, [(name,) for name in tag_names])
|
||||
|
||||
# 如果提供了 creator_id,则创建不存在的标签
|
||||
if creator_id:
|
||||
insert_ignore_query = "INSERT IGNORE INTO tags (name, creator_id) VALUES (%s, %s)"
|
||||
cursor.executemany(insert_ignore_query, [(name, creator_id) for name in tag_names])
|
||||
|
||||
# 查询所有标签信息
|
||||
format_strings = ', '.join(['%s'] * len(tag_names))
|
||||
cursor.execute(f"SELECT id, name, color FROM tags WHERE name IN ({format_strings})", tuple(tag_names))
|
||||
tags_data = cursor.fetchall()
|
||||
|
|
@ -125,10 +133,9 @@ def get_meeting_transcript(meeting_id: int, current_user: dict = Depends(get_cur
|
|||
def create_meeting(meeting_request: CreateMeetingRequest, current_user: dict = Depends(get_current_user)):
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
# 使用 _process_tags 来处理标签创建
|
||||
if meeting_request.tags:
|
||||
tag_names = [name.strip() for name in meeting_request.tags.split(',') if name.strip()]
|
||||
if tag_names:
|
||||
cursor.executemany("INSERT IGNORE INTO tags (name) VALUES (%s)", [(name,) for name in tag_names])
|
||||
_process_tags(cursor, meeting_request.tags, current_user['user_id'])
|
||||
meeting_query = 'INSERT INTO meetings (user_id, title, meeting_time, summary, tags, created_at) VALUES (%s, %s, %s, %s, %s, %s)'
|
||||
cursor.execute(meeting_query, (meeting_request.user_id, meeting_request.title, meeting_request.meeting_time, None, meeting_request.tags, datetime.now().isoformat()))
|
||||
meeting_id = cursor.lastrowid
|
||||
|
|
@ -147,10 +154,9 @@ def update_meeting(meeting_id: int, meeting_request: UpdateMeetingRequest, curre
|
|||
return create_api_response(code="404", message="Meeting not found")
|
||||
if meeting['user_id'] != current_user['user_id']:
|
||||
return create_api_response(code="403", message="Permission denied")
|
||||
# 使用 _process_tags 来处理标签创建
|
||||
if meeting_request.tags:
|
||||
tag_names = [name.strip() for name in meeting_request.tags.split(',') if name.strip()]
|
||||
if tag_names:
|
||||
cursor.executemany("INSERT IGNORE INTO tags (name) VALUES (%s)", [(name,) for name in tag_names])
|
||||
_process_tags(cursor, meeting_request.tags, current_user['user_id'])
|
||||
update_query = 'UPDATE meetings SET title = %s, meeting_time = %s, summary = %s, tags = %s WHERE meeting_id = %s'
|
||||
cursor.execute(update_query, (meeting_request.title, meeting_request.meeting_time, meeting_request.summary, meeting_request.tags, meeting_id))
|
||||
cursor.execute("DELETE FROM attendees WHERE meeting_id = %s", (meeting_id,))
|
||||
|
|
@ -449,8 +455,8 @@ def generate_meeting_summary_async(meeting_id: int, request: GenerateSummaryRequ
|
|||
cursor.execute("SELECT meeting_id FROM meetings WHERE meeting_id = %s", (meeting_id,))
|
||||
if not cursor.fetchone():
|
||||
return create_api_response(code="404", message="Meeting not found")
|
||||
task_id = async_llm_service.start_summary_generation(meeting_id, request.user_prompt)
|
||||
background_tasks.add_task(async_llm_service._process_task, task_id)
|
||||
task_id = async_meeting_service.start_summary_generation(meeting_id, request.user_prompt)
|
||||
background_tasks.add_task(async_meeting_service._process_task, task_id)
|
||||
return create_api_response(code="200", message="Summary generation task has been accepted.", data={
|
||||
"task_id": task_id, "status": "pending", "meeting_id": meeting_id
|
||||
})
|
||||
|
|
@ -465,7 +471,7 @@ def get_meeting_llm_tasks(meeting_id: int, current_user: dict = Depends(get_curr
|
|||
cursor.execute("SELECT meeting_id FROM meetings WHERE meeting_id = %s", (meeting_id,))
|
||||
if not cursor.fetchone():
|
||||
return create_api_response(code="404", message="Meeting not found")
|
||||
tasks = async_llm_service.get_meeting_llm_tasks(meeting_id)
|
||||
tasks = async_meeting_service.get_meeting_llm_tasks(meeting_id)
|
||||
return create_api_response(code="200", message="LLM tasks retrieved successfully", data={
|
||||
"tasks": tasks, "total": len(tasks)
|
||||
})
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from fastapi import APIRouter, Depends
|
|||
from pydantic import BaseModel
|
||||
from typing import List, Optional
|
||||
|
||||
from app.core.auth import get_current_admin_user
|
||||
from app.core.auth import get_current_user
|
||||
from app.core.database import get_db_connection
|
||||
from app.core.response import create_api_response
|
||||
|
||||
|
|
@ -23,14 +23,14 @@ class PromptListResponse(BaseModel):
|
|||
total: int
|
||||
|
||||
@router.post("/prompts")
|
||||
def create_prompt(prompt: PromptIn, current_user: dict = Depends(get_current_admin_user)):
|
||||
def create_prompt(prompt: PromptIn, current_user: dict = Depends(get_current_user)):
|
||||
"""Create a new prompt."""
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
try:
|
||||
cursor.execute(
|
||||
"INSERT INTO prompts (name, tags, content) VALUES (%s, %s, %s)",
|
||||
(prompt.name, prompt.tags, prompt.content)
|
||||
"INSERT INTO prompts (name, tags, content, creator_id) VALUES (%s, %s, %s, %s)",
|
||||
(prompt.name, prompt.tags, prompt.content, current_user["user_id"])
|
||||
)
|
||||
connection.commit()
|
||||
new_id = cursor.lastrowid
|
||||
|
|
@ -41,23 +41,27 @@ def create_prompt(prompt: PromptIn, current_user: dict = Depends(get_current_adm
|
|||
return create_api_response(code="500", message=f"创建提示词失败: {e}")
|
||||
|
||||
@router.get("/prompts")
|
||||
def get_prompts(page: int = 1, size: int = 12, current_user: dict = Depends(get_current_admin_user)):
|
||||
"""Get a paginated list of prompts."""
|
||||
def get_prompts(page: int = 1, size: int = 12, current_user: dict = Depends(get_current_user)):
|
||||
"""Get a paginated list of prompts filtered by current user."""
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
cursor.execute("SELECT COUNT(*) as total FROM prompts")
|
||||
# 只获取当前用户创建的提示词
|
||||
cursor.execute(
|
||||
"SELECT COUNT(*) as total FROM prompts WHERE creator_id = %s",
|
||||
(current_user["user_id"],)
|
||||
)
|
||||
total = cursor.fetchone()['total']
|
||||
|
||||
offset = (page - 1) * size
|
||||
cursor.execute(
|
||||
"SELECT id, name, tags, content, created_at FROM prompts ORDER BY created_at DESC LIMIT %s OFFSET %s",
|
||||
(size, offset)
|
||||
"SELECT id, name, tags, content, created_at FROM prompts WHERE creator_id = %s ORDER BY created_at DESC LIMIT %s OFFSET %s",
|
||||
(current_user["user_id"], size, offset)
|
||||
)
|
||||
prompts = cursor.fetchall()
|
||||
return create_api_response(code="200", message="获取提示词列表成功", data={"prompts": prompts, "total": total})
|
||||
|
||||
@router.get("/prompts/{prompt_id}")
|
||||
def get_prompt(prompt_id: int, current_user: dict = Depends(get_current_admin_user)):
|
||||
def get_prompt(prompt_id: int, current_user: dict = Depends(get_current_user)):
|
||||
"""Get a single prompt by its ID."""
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
|
|
@ -68,7 +72,7 @@ def get_prompt(prompt_id: int, current_user: dict = Depends(get_current_admin_us
|
|||
return create_api_response(code="200", message="获取提示词成功", data=prompt)
|
||||
|
||||
@router.put("/prompts/{prompt_id}")
|
||||
def update_prompt(prompt_id: int, prompt: PromptIn, current_user: dict = Depends(get_current_admin_user)):
|
||||
def update_prompt(prompt_id: int, prompt: PromptIn, current_user: dict = Depends(get_current_user)):
|
||||
"""Update an existing prompt."""
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
|
|
@ -87,12 +91,24 @@ def update_prompt(prompt_id: int, prompt: PromptIn, current_user: dict = Depends
|
|||
return create_api_response(code="500", message=f"更新提示词失败: {e}")
|
||||
|
||||
@router.delete("/prompts/{prompt_id}")
|
||||
def delete_prompt(prompt_id: int, current_user: dict = Depends(get_current_admin_user)):
|
||||
"""Delete a prompt."""
|
||||
def delete_prompt(prompt_id: int, current_user: dict = Depends(get_current_user)):
|
||||
"""Delete a prompt. Only the creator can delete their own prompts."""
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor()
|
||||
cursor.execute("DELETE FROM prompts WHERE id = %s", (prompt_id,))
|
||||
if cursor.rowcount == 0:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
# 首先检查提示词是否存在以及是否属于当前用户
|
||||
cursor.execute(
|
||||
"SELECT creator_id FROM prompts WHERE id = %s",
|
||||
(prompt_id,)
|
||||
)
|
||||
prompt = cursor.fetchone()
|
||||
|
||||
if not prompt:
|
||||
return create_api_response(code="404", message="提示词不存在")
|
||||
|
||||
if prompt['creator_id'] != current_user["user_id"]:
|
||||
return create_api_response(code="403", message="无权删除其他用户的提示词")
|
||||
|
||||
# 删除提示词
|
||||
cursor.execute("DELETE FROM prompts WHERE id = %s", (prompt_id,))
|
||||
connection.commit()
|
||||
return create_api_response(code="200", message="提示词删除成功")
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from fastapi import APIRouter, Depends
|
||||
from app.core.database import get_db_connection
|
||||
from app.core.response import create_api_response
|
||||
from app.core.auth import get_current_user
|
||||
from app.models.models import Tag
|
||||
from typing import List
|
||||
import mysql.connector
|
||||
|
|
@ -24,16 +25,16 @@ def get_all_tags():
|
|||
return create_api_response(code="500", message="获取标签失败")
|
||||
|
||||
@router.post("/tags/")
|
||||
def create_tag(tag_in: Tag):
|
||||
def create_tag(tag_in: Tag, current_user: dict = Depends(get_current_user)):
|
||||
"""_summary_
|
||||
创建一个新标签
|
||||
创建一个新标签,并记录创建者
|
||||
"""
|
||||
query = "INSERT INTO tags (name, color) VALUES (%s, %s)"
|
||||
query = "INSERT INTO tags (name, color, creator_id) VALUES (%s, %s, %s)"
|
||||
try:
|
||||
with get_db_connection() as connection:
|
||||
with connection.cursor(dictionary=True) as cursor:
|
||||
try:
|
||||
cursor.execute(query, (tag_in.name, tag_in.color))
|
||||
cursor.execute(query, (tag_in.name, tag_in.color, current_user["user_id"]))
|
||||
connection.commit()
|
||||
tag_id = cursor.lastrowid
|
||||
new_tag = {"id": tag_id, "name": tag_in.name, "color": tag_in.color}
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from fastapi import APIRouter, Depends
|
|||
from app.core.auth import get_current_user
|
||||
from app.core.response import create_api_response
|
||||
from app.services.async_transcription_service import AsyncTranscriptionService
|
||||
from app.services.async_llm_service import async_llm_service
|
||||
from app.services.async_meeting_service import async_meeting_service
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
|
@ -23,7 +23,7 @@ def get_transcription_task_status(task_id: str, current_user: dict = Depends(get
|
|||
def get_llm_task_status(task_id: str, current_user: dict = Depends(get_current_user)):
|
||||
"""获取LLM总结任务状态(包括进度)"""
|
||||
try:
|
||||
status = async_llm_service.get_task_status(task_id)
|
||||
status = async_meeting_service.get_task_status(task_id)
|
||||
if status.get('status') == 'not_found':
|
||||
return create_api_response(code="404", message="Task not found")
|
||||
return create_api_response(code="200", message="Task status retrieved", data=status)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,131 @@
|
|||
"""
|
||||
声纹采集API接口
|
||||
"""
|
||||
from fastapi import APIRouter, Depends, UploadFile, File, HTTPException
|
||||
from typing import Optional
|
||||
from pathlib import Path
|
||||
import datetime
|
||||
|
||||
from app.models.models import VoiceprintStatus, VoiceprintTemplate
|
||||
from app.core.auth import get_current_user
|
||||
from app.core.response import create_api_response
|
||||
from app.services.voiceprint_service import voiceprint_service
|
||||
import app.core.config as config_module
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/voiceprint/template", response_model=None)
|
||||
def get_voiceprint_template(current_user: dict = Depends(get_current_user)):
|
||||
"""
|
||||
获取声纹采集朗读模板配置
|
||||
|
||||
权限:需要登录
|
||||
"""
|
||||
try:
|
||||
template_data = VoiceprintTemplate(**config_module.VOICEPRINT_CONFIG)
|
||||
return create_api_response(code="200", message="获取朗读模板成功", data=template_data.dict())
|
||||
except Exception as e:
|
||||
return create_api_response(code="500", message=f"获取朗读模板失败: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/voiceprint/{user_id}", response_model=None)
|
||||
def get_voiceprint_status(user_id: int, current_user: dict = Depends(get_current_user)):
|
||||
"""
|
||||
获取用户声纹采集状态
|
||||
|
||||
权限:用户只能查询自己的声纹状态,管理员可查询所有
|
||||
"""
|
||||
# 权限检查:只能查询自己的声纹,或者是管理员
|
||||
if current_user['user_id'] != user_id and current_user['role_id'] != 1:
|
||||
return create_api_response(code="403", message="无权限查询其他用户的声纹状态")
|
||||
|
||||
try:
|
||||
status_data = voiceprint_service.get_user_voiceprint_status(user_id)
|
||||
return create_api_response(code="200", message="获取声纹状态成功", data=status_data)
|
||||
except Exception as e:
|
||||
return create_api_response(code="500", message=f"获取声纹状态失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/voiceprint/{user_id}", response_model=None)
|
||||
async def upload_voiceprint(
|
||||
user_id: int,
|
||||
audio_file: UploadFile = File(...),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
上传声纹音频文件(同步处理)
|
||||
|
||||
权限:用户只能上传自己的声纹,管理员可操作所有
|
||||
"""
|
||||
# 权限检查
|
||||
if current_user['user_id'] != user_id and current_user['role_id'] != 1:
|
||||
return create_api_response(code="403", message="无权限上传其他用户的声纹")
|
||||
|
||||
# 检查文件格式
|
||||
file_ext = Path(audio_file.filename).suffix.lower()
|
||||
if file_ext not in config_module.ALLOWED_VOICEPRINT_EXTENSIONS:
|
||||
return create_api_response(
|
||||
code="400",
|
||||
message=f"不支持的文件格式,仅支持: {', '.join(config_module.ALLOWED_VOICEPRINT_EXTENSIONS)}"
|
||||
)
|
||||
|
||||
# 检查文件大小
|
||||
max_size = config_module.VOICEPRINT_CONFIG.get('max_file_size', 5242880) # 默认5MB
|
||||
content = await audio_file.read()
|
||||
file_size = len(content)
|
||||
|
||||
if file_size > max_size:
|
||||
return create_api_response(
|
||||
code="400",
|
||||
message=f"文件过大,最大允许 {max_size / 1024 / 1024:.1f}MB"
|
||||
)
|
||||
|
||||
try:
|
||||
# 确保用户目录存在
|
||||
user_voiceprint_dir = config_module.VOICEPRINT_DIR / str(user_id)
|
||||
user_voiceprint_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 生成文件名:时间戳.wav
|
||||
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
filename = f"{timestamp}.wav"
|
||||
file_path = user_voiceprint_dir / filename
|
||||
|
||||
# 保存文件
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(content)
|
||||
|
||||
# 调用服务处理声纹(提取特征向量,保存到数据库)
|
||||
result = voiceprint_service.save_voiceprint(user_id, str(file_path), file_size)
|
||||
|
||||
return create_api_response(code="200", message="声纹采集成功", data=result)
|
||||
|
||||
except Exception as e:
|
||||
# 如果出错,删除已上传的文件
|
||||
if 'file_path' in locals() and Path(file_path).exists():
|
||||
Path(file_path).unlink()
|
||||
|
||||
return create_api_response(code="500", message=f"声纹采集失败: {str(e)}")
|
||||
|
||||
|
||||
@router.delete("/voiceprint/{user_id}", response_model=None)
|
||||
def delete_voiceprint(user_id: int, current_user: dict = Depends(get_current_user)):
|
||||
"""
|
||||
删除用户声纹数据,允许重新采集
|
||||
|
||||
权限:用户只能删除自己的声纹,管理员可操作所有
|
||||
"""
|
||||
# 权限检查
|
||||
if current_user['user_id'] != user_id and current_user['role_id'] != 1:
|
||||
return create_api_response(code="403", message="无权限删除其他用户的声纹")
|
||||
|
||||
try:
|
||||
success = voiceprint_service.delete_voiceprint(user_id)
|
||||
|
||||
if success:
|
||||
return create_api_response(code="200", message="声纹删除成功")
|
||||
else:
|
||||
return create_api_response(code="404", message="未找到该用户的声纹数据")
|
||||
|
||||
except Exception as e:
|
||||
return create_api_response(code="500", message=f"删除声纹失败: {str(e)}")
|
||||
|
|
@ -1,4 +1,5 @@
|
|||
import os
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
# 基础路径配置
|
||||
|
|
@ -6,10 +7,12 @@ BASE_DIR = Path(__file__).parent.parent.parent
|
|||
UPLOAD_DIR = BASE_DIR / "uploads"
|
||||
AUDIO_DIR = UPLOAD_DIR / "audio"
|
||||
MARKDOWN_DIR = UPLOAD_DIR / "markdown"
|
||||
VOICEPRINT_DIR = UPLOAD_DIR / "voiceprint"
|
||||
|
||||
# 文件上传配置
|
||||
ALLOWED_EXTENSIONS = {".mp3", ".wav", ".m4a", ".mpeg", ".mp4"}
|
||||
ALLOWED_IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".gif", ".webp"}
|
||||
ALLOWED_VOICEPRINT_EXTENSIONS = {".wav"}
|
||||
MAX_FILE_SIZE = 100 * 1024 * 1024 # 100MB
|
||||
MAX_IMAGE_SIZE = 10 * 1024 * 1024 # 10MB
|
||||
|
||||
|
|
@ -17,6 +20,7 @@ MAX_IMAGE_SIZE = 10 * 1024 * 1024 # 10MB
|
|||
UPLOAD_DIR.mkdir(exist_ok=True)
|
||||
AUDIO_DIR.mkdir(exist_ok=True)
|
||||
MARKDOWN_DIR.mkdir(exist_ok=True)
|
||||
VOICEPRINT_DIR.mkdir(exist_ok=True)
|
||||
|
||||
# 数据库配置
|
||||
DATABASE_CONFIG = {
|
||||
|
|
@ -82,3 +86,12 @@ LLM_CONFIG = {
|
|||
|
||||
# 密码重置配置
|
||||
DEFAULT_RESET_PASSWORD = os.getenv('DEFAULT_RESET_PASSWORD', '111111')
|
||||
|
||||
# 加载系统配置文件
|
||||
# 默认声纹配置
|
||||
VOICEPRINT_CONFIG = {
|
||||
"template_text": "我正在进行声纹采集,这段语音将用于身份识别和验证。\n\n声纹技术能够准确识别每个人独特的声音特征。",
|
||||
"duration_seconds": 12,
|
||||
"sample_rate": 16000,
|
||||
"channels": 1
|
||||
}
|
||||
|
|
|
|||
|
|
@ -128,7 +128,7 @@ class KnowledgeBase(BaseModel):
|
|||
is_shared: bool
|
||||
source_meeting_ids: Optional[str] = None
|
||||
user_prompt: Optional[str] = None
|
||||
tags: Optional[List[Tag]] = []
|
||||
tags: Union[Optional[str], Optional[List[Tag]]] = None # 支持字符串或Tag列表
|
||||
created_at: datetime.datetime
|
||||
updated_at: datetime.datetime
|
||||
source_meeting_count: Optional[int] = 0
|
||||
|
|
@ -204,3 +204,26 @@ class UpdateClientDownloadRequest(BaseModel):
|
|||
class ClientDownloadListResponse(BaseModel):
|
||||
clients: List[ClientDownload]
|
||||
total: int
|
||||
|
||||
# 声纹采集相关模型
|
||||
class VoiceprintInfo(BaseModel):
|
||||
vp_id: int
|
||||
user_id: int
|
||||
file_path: str
|
||||
file_size: Optional[int] = None
|
||||
duration_seconds: Optional[float] = None
|
||||
collected_at: datetime.datetime
|
||||
updated_at: datetime.datetime
|
||||
|
||||
class VoiceprintStatus(BaseModel):
|
||||
has_voiceprint: bool
|
||||
vp_id: Optional[int] = None
|
||||
file_path: Optional[str] = None
|
||||
duration_seconds: Optional[float] = None
|
||||
collected_at: Optional[datetime.datetime] = None
|
||||
|
||||
class VoiceprintTemplate(BaseModel):
|
||||
template_text: str
|
||||
duration_seconds: int
|
||||
sample_rate: int
|
||||
channels: int
|
||||
|
|
|
|||
|
|
@ -1,3 +1,7 @@
|
|||
"""
|
||||
异步知识库服务 - 处理知识库生成的异步任务
|
||||
采用FastAPI BackgroundTasks模式
|
||||
"""
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any, List
|
||||
|
|
@ -7,6 +11,7 @@ from app.core.database import get_db_connection
|
|||
from app.services.llm_service import LLMService
|
||||
|
||||
class AsyncKnowledgeBaseService:
|
||||
"""异步知识库服务类 - 处理知识库相关的异步任务"""
|
||||
|
||||
def __init__(self):
|
||||
from app.core.config import REDIS_CONFIG
|
||||
|
|
@ -16,6 +21,19 @@ class AsyncKnowledgeBaseService:
|
|||
self.llm_service = LLMService()
|
||||
|
||||
def start_generation(self, user_id: int, kb_id: int, user_prompt: Optional[str], source_meeting_ids: Optional[str], cursor=None) -> str:
|
||||
"""
|
||||
创建异步知识库生成任务
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
kb_id: 知识库ID
|
||||
user_prompt: 用户提示词
|
||||
source_meeting_ids: 源会议ID列表
|
||||
cursor: 数据库游标(可选)
|
||||
|
||||
Returns:
|
||||
str: 任务ID
|
||||
"""
|
||||
task_id = str(uuid.uuid4())
|
||||
|
||||
# If a cursor is passed, use it directly to avoid creating a new transaction
|
||||
|
|
@ -47,8 +65,12 @@ class AsyncKnowledgeBaseService:
|
|||
return task_id
|
||||
|
||||
def _process_task(self, task_id: str):
|
||||
"""
|
||||
处理单个异步任务的函数,设计为由BackgroundTasks调用。
|
||||
"""
|
||||
print(f"Background task started for knowledge base task: {task_id}")
|
||||
try:
|
||||
# 从Redis获取任务数据
|
||||
task_data = self.redis_client.hgetall(f"kb_task:{task_id}")
|
||||
if not task_data:
|
||||
print(f"Error: Task {task_id} not found in Redis for processing.")
|
||||
|
|
@ -57,99 +79,136 @@ class AsyncKnowledgeBaseService:
|
|||
kb_id = int(task_data['kb_id'])
|
||||
user_prompt = task_data.get('user_prompt', '')
|
||||
|
||||
# 1. 更新状态为processing
|
||||
self._update_task_status_in_redis(task_id, 'processing', 10, message="任务已开始...")
|
||||
|
||||
# 2. 获取关联的会议总结
|
||||
self._update_task_status_in_redis(task_id, 'processing', 20, message="获取关联会议纪要...")
|
||||
source_text = self._get_meeting_summaries(kb_id)
|
||||
|
||||
# 3. 构建提示词
|
||||
self._update_task_status_in_redis(task_id, 'processing', 30, message="准备AI提示词...")
|
||||
full_prompt = self._build_prompt(source_text, user_prompt)
|
||||
|
||||
# 4. 调用LLM API
|
||||
self._update_task_status_in_redis(task_id, 'processing', 50, message="AI正在生成知识库...")
|
||||
generated_content = self.llm_service._call_llm_api(full_prompt)
|
||||
if not generated_content:
|
||||
raise Exception("LLM API调用失败或返回空内容")
|
||||
|
||||
# 5. 保存结果到数据库
|
||||
self._update_task_status_in_redis(task_id, 'processing', 95, message="保存结果...")
|
||||
self._save_result_to_db(kb_id, generated_content)
|
||||
|
||||
# 6. 任务完成
|
||||
self._update_task_in_db(task_id, 'completed', 100)
|
||||
self._update_task_status_in_redis(task_id, 'completed', 100)
|
||||
|
||||
print(f"Task {task_id} completed successfully")
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
print(f"Task {task_id} failed: {error_msg}")
|
||||
# 更新失败状态
|
||||
self._update_task_in_db(task_id, 'failed', 0, error_message=error_msg)
|
||||
self._update_task_status_in_redis(task_id, 'failed', 0, error_message=error_msg)
|
||||
|
||||
# --- 知识库相关方法 ---
|
||||
|
||||
def _get_meeting_summaries(self, kb_id: int) -> str:
|
||||
"""
|
||||
从数据库获取知识库关联的会议总结
|
||||
|
||||
Args:
|
||||
kb_id: 知识库ID
|
||||
|
||||
Returns:
|
||||
str: 拼接后的会议总结文本
|
||||
"""
|
||||
try:
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
|
||||
# Get source meeting summaries
|
||||
source_text = ""
|
||||
# 获取知识库的源会议ID列表
|
||||
cursor.execute("SELECT source_meeting_ids FROM knowledge_bases WHERE kb_id = %s", (kb_id,))
|
||||
kb_info = cursor.fetchone()
|
||||
if kb_info and kb_info['source_meeting_ids']:
|
||||
self._update_task_status_in_redis(task_id, 'processing', 20, message="获取关联会议纪要...")
|
||||
|
||||
if not kb_info or not kb_info['source_meeting_ids']:
|
||||
return ""
|
||||
|
||||
# 解析会议ID列表
|
||||
meeting_ids = [int(m_id) for m_id in kb_info['source_meeting_ids'].split(',') if m_id.isdigit()]
|
||||
if meeting_ids:
|
||||
if not meeting_ids:
|
||||
return ""
|
||||
|
||||
# 获取所有会议的总结
|
||||
summaries = []
|
||||
for meeting_id in meeting_ids:
|
||||
cursor.execute("SELECT summary FROM meetings WHERE meeting_id = %s", (meeting_id,))
|
||||
summary = cursor.fetchone()
|
||||
if summary and summary['summary']:
|
||||
summaries.append(summary['summary'])
|
||||
source_text = "\n\n---\n\n".join(summaries)
|
||||
|
||||
# Get system prompt
|
||||
self._update_task_status_in_redis(task_id, 'processing', 30, message="获取知识库生成模版...")
|
||||
system_prompt = self._get_knowledge_task_prompt(cursor)
|
||||
|
||||
# Build final prompt
|
||||
final_prompt = f"{system_prompt}\n\n"
|
||||
if source_text:
|
||||
final_prompt += f"请参考以下会议纪要内容:\n{source_text}\n\n"
|
||||
final_prompt += f"用户要求:{user_prompt}"
|
||||
|
||||
self._update_task_status_in_redis(task_id, 'processing', 50, message="AI正在生成知识库...")
|
||||
generated_content = self.llm_service._call_llm_api(final_prompt)
|
||||
|
||||
if not generated_content:
|
||||
raise Exception("LLM API call failed or returned empty content")
|
||||
|
||||
self._update_task_status_in_redis(task_id, 'processing', 95, message="保存结果...")
|
||||
self._save_result_to_db(kb_id, generated_content, cursor)
|
||||
|
||||
self._update_task_in_db(task_id, 'completed', 100, cursor=cursor)
|
||||
self._update_task_status_in_redis(task_id, 'completed', 100)
|
||||
|
||||
connection.commit()
|
||||
print(f"Task {task_id} completed successfully")
|
||||
# 用分隔符拼接多个会议总结
|
||||
return "\n\n---\n\n".join(summaries)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
print(f"Task {task_id} failed: {error_msg}")
|
||||
# Use a new connection for error logging to avoid issues with a potentially broken transaction
|
||||
with get_db_connection() as err_conn:
|
||||
err_cursor = err_conn.cursor()
|
||||
self._update_task_in_db(task_id, 'failed', 0, error_message=error_msg, cursor=err_cursor)
|
||||
err_conn.commit()
|
||||
self._update_task_status_in_redis(task_id, 'failed', 0, error_message=error_msg)
|
||||
print(f"获取会议总结错误: {e}")
|
||||
return ""
|
||||
|
||||
def _get_knowledge_task_prompt(self, cursor) -> str:
|
||||
query = """
|
||||
SELECT p.content
|
||||
FROM prompt_config pc
|
||||
JOIN prompts p ON pc.prompt_id = p.id
|
||||
WHERE pc.task_name = 'KNOWLEDGE_TASK'
|
||||
def _build_prompt(self, source_text: str, user_prompt: str) -> str:
|
||||
"""
|
||||
cursor.execute(query)
|
||||
result = cursor.fetchone()
|
||||
if result:
|
||||
return result['content']
|
||||
else:
|
||||
# Fallback prompt
|
||||
return "Please generate a knowledge base article based on the provided information."
|
||||
构建完整的提示词
|
||||
使用数据库中配置的KNOWLEDGE_TASK提示词模板
|
||||
|
||||
Args:
|
||||
source_text: 源会议总结文本
|
||||
user_prompt: 用户自定义提示词
|
||||
|
||||
Returns:
|
||||
str: 完整的提示词
|
||||
"""
|
||||
# 从数据库获取知识库任务的提示词模板
|
||||
system_prompt = self.llm_service.get_task_prompt('KNOWLEDGE_TASK')
|
||||
|
||||
def _save_result_to_db(self, kb_id: int, content: str, cursor):
|
||||
prompt = f"{system_prompt}\n\n"
|
||||
|
||||
if source_text:
|
||||
prompt += f"请参考以下会议纪要内容:\n{source_text}\n\n"
|
||||
|
||||
prompt += f"用户要求:{user_prompt}"
|
||||
|
||||
return prompt
|
||||
|
||||
def _save_result_to_db(self, kb_id: int, content: str) -> Optional[int]:
|
||||
"""
|
||||
保存生成结果到数据库
|
||||
|
||||
Args:
|
||||
kb_id: 知识库ID
|
||||
content: 生成的内容
|
||||
|
||||
Returns:
|
||||
Optional[int]: 知识库ID,失败返回None
|
||||
"""
|
||||
try:
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor()
|
||||
query = "UPDATE knowledge_bases SET content = %s, updated_at = NOW() WHERE kb_id = %s"
|
||||
cursor.execute(query, (content, kb_id))
|
||||
connection.commit()
|
||||
|
||||
def _update_task_in_db(self, task_id: str, status: str, progress: int, error_message: str = None, cursor=None):
|
||||
query = "UPDATE knowledge_base_tasks SET status = %s, progress = %s, error_message = %s, updated_at = NOW(), completed_at = IF(%s = 'completed', NOW(), completed_at) WHERE task_id = %s"
|
||||
cursor.execute(query, (status, progress, error_message, status, task_id))
|
||||
print(f"成功保存知识库内容,kb_id: {kb_id}")
|
||||
return kb_id
|
||||
|
||||
def _update_task_status_in_redis(self, task_id: str, status: str, progress: int, message: str = None, error_message: str = None):
|
||||
update_data = {
|
||||
'status': status,
|
||||
'progress': str(progress),
|
||||
'updated_at': datetime.now().isoformat()
|
||||
}
|
||||
if message: update_data['message'] = message
|
||||
if error_message: update_data['error_message'] = error_message
|
||||
self.redis_client.hset(f"kb_task:{task_id}", mapping=update_data)
|
||||
except Exception as e:
|
||||
print(f"保存知识库内容错误: {e}")
|
||||
return None
|
||||
|
||||
# --- 状态查询和数据库操作方法 ---
|
||||
|
||||
def get_task_status(self, task_id: str) -> Dict[str, Any]:
|
||||
"""获取任务状态 - 与 async_llm_service 保持一致"""
|
||||
"""获取任务状态"""
|
||||
try:
|
||||
task_data = self.redis_client.hgetall(f"kb_task:{task_id}")
|
||||
if not task_data:
|
||||
|
|
@ -170,6 +229,20 @@ class AsyncKnowledgeBaseService:
|
|||
print(f"Error getting task status: {e}")
|
||||
return {'task_id': task_id, 'status': 'error', 'error_message': str(e)}
|
||||
|
||||
def _update_task_status_in_redis(self, task_id: str, status: str, progress: int, message: str = None, error_message: str = None):
|
||||
"""更新Redis中的任务状态"""
|
||||
try:
|
||||
update_data = {
|
||||
'status': status,
|
||||
'progress': str(progress),
|
||||
'updated_at': datetime.now().isoformat()
|
||||
}
|
||||
if message: update_data['message'] = message
|
||||
if error_message: update_data['error_message'] = error_message
|
||||
self.redis_client.hset(f"kb_task:{task_id}", mapping=update_data)
|
||||
except Exception as e:
|
||||
print(f"Error updating task status in Redis: {e}")
|
||||
|
||||
def _save_task_to_db(self, task_id: str, user_id: int, kb_id: int, user_prompt: str):
|
||||
"""保存任务到数据库"""
|
||||
try:
|
||||
|
|
@ -182,6 +255,17 @@ class AsyncKnowledgeBaseService:
|
|||
print(f"Error saving task to database: {e}")
|
||||
raise
|
||||
|
||||
def _update_task_in_db(self, task_id: str, status: str, progress: int, error_message: str = None):
|
||||
"""更新数据库中的任务状态"""
|
||||
try:
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor()
|
||||
query = "UPDATE knowledge_base_tasks SET status = %s, progress = %s, error_message = %s, updated_at = NOW(), completed_at = IF(%s = 'completed', NOW(), completed_at) WHERE task_id = %s"
|
||||
cursor.execute(query, (status, progress, error_message, status, task_id))
|
||||
connection.commit()
|
||||
except Exception as e:
|
||||
print(f"Error updating task in database: {e}")
|
||||
|
||||
def _get_task_from_db(self, task_id: str) -> Optional[Dict[str, str]]:
|
||||
"""从数据库获取任务信息"""
|
||||
try:
|
||||
|
|
@ -198,4 +282,5 @@ class AsyncKnowledgeBaseService:
|
|||
print(f"Error getting task from database: {e}")
|
||||
return None
|
||||
|
||||
# 创建全局实例
|
||||
async_kb_service = AsyncKnowledgeBaseService()
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
"""
|
||||
异步LLM服务 - 处理会议总结生成的异步任务
|
||||
异步会议服务 - 处理会议总结生成的异步任务
|
||||
采用FastAPI BackgroundTasks模式
|
||||
"""
|
||||
import uuid
|
||||
|
|
@ -12,8 +12,8 @@ from app.core.config import REDIS_CONFIG
|
|||
from app.core.database import get_db_connection
|
||||
from app.services.llm_service import LLMService
|
||||
|
||||
class AsyncLLMService:
|
||||
"""异步LLM服务类 - 采用FastAPI BackgroundTasks模式"""
|
||||
class AsyncMeetingService:
|
||||
"""异步会议服务类 - 处理会议相关的异步任务"""
|
||||
|
||||
def __init__(self):
|
||||
# 确保redis客户端自动解码响应,代码更简洁
|
||||
|
|
@ -53,7 +53,7 @@ class AsyncLLMService:
|
|||
self.redis_client.hset(f"llm_task:{task_id}", mapping=task_data)
|
||||
self.redis_client.expire(f"llm_task:{task_id}", 86400)
|
||||
|
||||
print(f"LLM summary task created: {task_id} for meeting: {meeting_id}")
|
||||
print(f"Meeting summary task created: {task_id} for meeting: {meeting_id}")
|
||||
return task_id
|
||||
|
||||
except Exception as e:
|
||||
|
|
@ -64,7 +64,7 @@ class AsyncLLMService:
|
|||
"""
|
||||
处理单个异步任务的函数,设计为由BackgroundTasks调用。
|
||||
"""
|
||||
print(f"Background task started for LLM task: {task_id}")
|
||||
print(f"Background task started for meeting summary task: {task_id}")
|
||||
try:
|
||||
# 从Redis获取任务数据
|
||||
task_data = self.redis_client.hgetall(f"llm_task:{task_id}")
|
||||
|
|
@ -80,13 +80,13 @@ class AsyncLLMService:
|
|||
|
||||
# 2. 获取会议转录内容
|
||||
self._update_task_status_in_redis(task_id, 'processing', 30, message="获取会议转录内容...")
|
||||
transcript_text = self.llm_service._get_meeting_transcript(meeting_id)
|
||||
transcript_text = self._get_meeting_transcript(meeting_id)
|
||||
if not transcript_text:
|
||||
raise Exception("无法获取会议转录内容")
|
||||
|
||||
# 3. 构建提示词
|
||||
self._update_task_status_in_redis(task_id, 'processing', 40, message="准备AI提示词...")
|
||||
full_prompt = self.llm_service._build_prompt(transcript_text, user_prompt)
|
||||
full_prompt = self._build_prompt(transcript_text, user_prompt)
|
||||
|
||||
# 4. 调用LLM API
|
||||
self._update_task_status_in_redis(task_id, 'processing', 50, message="AI正在分析会议内容...")
|
||||
|
|
@ -96,7 +96,7 @@ class AsyncLLMService:
|
|||
|
||||
# 5. 保存结果到主表
|
||||
self._update_task_status_in_redis(task_id, 'processing', 95, message="保存总结结果...")
|
||||
self.llm_service._save_summary_to_db(meeting_id, summary_content, user_prompt)
|
||||
self._save_summary_to_db(meeting_id, summary_content, user_prompt)
|
||||
|
||||
# 6. 任务完成
|
||||
self._update_task_in_db(task_id, 'completed', 100, result=summary_content)
|
||||
|
|
@ -110,6 +110,78 @@ class AsyncLLMService:
|
|||
self._update_task_in_db(task_id, 'failed', 0, error_message=error_msg)
|
||||
self._update_task_status_in_redis(task_id, 'failed', 0, error_message=error_msg)
|
||||
|
||||
# --- 会议相关方法 ---
|
||||
|
||||
def _get_meeting_transcript(self, meeting_id: int) -> str:
|
||||
"""从数据库获取会议转录内容"""
|
||||
try:
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor()
|
||||
query = """
|
||||
SELECT speaker_tag, start_time_ms, end_time_ms, text_content
|
||||
FROM transcript_segments
|
||||
WHERE meeting_id = %s
|
||||
ORDER BY start_time_ms
|
||||
"""
|
||||
cursor.execute(query, (meeting_id,))
|
||||
segments = cursor.fetchall()
|
||||
|
||||
if not segments:
|
||||
return ""
|
||||
|
||||
# 组装转录文本
|
||||
transcript_lines = []
|
||||
for speaker_tag, start_time, end_time, text in segments:
|
||||
# 将毫秒转换为分:秒格式
|
||||
start_min = start_time // 60000
|
||||
start_sec = (start_time % 60000) // 1000
|
||||
transcript_lines.append(f"[{start_min:02d}:{start_sec:02d}] 说话人{speaker_tag}: {text}")
|
||||
|
||||
return "\n".join(transcript_lines)
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取会议转录内容错误: {e}")
|
||||
return ""
|
||||
|
||||
def _build_prompt(self, transcript_text: str, user_prompt: str) -> str:
|
||||
"""
|
||||
构建完整的提示词
|
||||
使用数据库中配置的MEETING_TASK提示词模板
|
||||
"""
|
||||
# 从数据库获取会议任务的提示词模板
|
||||
system_prompt = self.llm_service.get_task_prompt('MEETING_TASK')
|
||||
|
||||
prompt = f"{system_prompt}\n\n"
|
||||
|
||||
if user_prompt:
|
||||
prompt += f"用户额外要求:{user_prompt}\n\n"
|
||||
|
||||
prompt += f"会议转录内容:\n{transcript_text}\n\n请根据以上内容生成会议总结:"
|
||||
|
||||
return prompt
|
||||
|
||||
def _save_summary_to_db(self, meeting_id: int, summary_content: str, user_prompt: str) -> Optional[int]:
|
||||
"""保存总结到数据库 - 更新meetings表的summary、user_prompt和updated_at字段"""
|
||||
try:
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor()
|
||||
|
||||
# 更新meetings表的summary、user_prompt和updated_at字段
|
||||
update_query = """
|
||||
UPDATE meetings
|
||||
SET summary = %s, user_prompt = %s, updated_at = NOW()
|
||||
WHERE meeting_id = %s
|
||||
"""
|
||||
cursor.execute(update_query, (summary_content, user_prompt, meeting_id))
|
||||
connection.commit()
|
||||
|
||||
print(f"成功保存会议总结到meetings表,meeting_id: {meeting_id}")
|
||||
return meeting_id
|
||||
|
||||
except Exception as e:
|
||||
print(f"保存总结到数据库错误: {e}")
|
||||
return None
|
||||
|
||||
# --- 状态查询和数据库操作方法 ---
|
||||
|
||||
def get_task_status(self, task_id: str) -> Dict[str, Any]:
|
||||
|
|
@ -212,4 +284,4 @@ class AsyncLLMService:
|
|||
return None
|
||||
|
||||
# 创建全局实例
|
||||
async_llm_service = AsyncLLMService()
|
||||
async_meeting_service = AsyncMeetingService()
|
||||
|
|
@ -7,6 +7,8 @@ from app.core.database import get_db_connection
|
|||
|
||||
|
||||
class LLMService:
|
||||
"""LLM服务 - 专注于大模型API调用和提示词管理"""
|
||||
|
||||
def __init__(self):
|
||||
# 设置dashscope API key
|
||||
dashscope.api_key = config_module.QWEN_API_KEY
|
||||
|
|
@ -36,123 +38,47 @@ class LLMService:
|
|||
"""动态获取top_p"""
|
||||
return config_module.LLM_CONFIG["top_p"]
|
||||
|
||||
def generate_meeting_summary_stream(self, meeting_id: int, user_prompt: str = "") -> Generator[str, None, None]:
|
||||
def get_task_prompt(self, task_name: str, cursor=None) -> str:
|
||||
"""
|
||||
流式生成会议总结
|
||||
统一的提示词获取方法
|
||||
|
||||
Args:
|
||||
meeting_id: 会议ID
|
||||
user_prompt: 用户额外提示词
|
||||
|
||||
Yields:
|
||||
str: 流式输出的内容片段
|
||||
"""
|
||||
try:
|
||||
# 获取会议转录内容
|
||||
transcript_text = self._get_meeting_transcript(meeting_id)
|
||||
if not transcript_text:
|
||||
yield "error: 无法获取会议转录内容"
|
||||
return
|
||||
|
||||
# 构建完整提示词
|
||||
full_prompt = self._build_prompt(transcript_text, user_prompt)
|
||||
|
||||
# 调用大模型API进行流式生成
|
||||
full_content = ""
|
||||
for chunk in self._call_llm_api_stream(full_prompt):
|
||||
if chunk.startswith("error:"):
|
||||
yield chunk
|
||||
return
|
||||
full_content += chunk
|
||||
yield chunk
|
||||
|
||||
# 保存完整总结到数据库
|
||||
if full_content:
|
||||
self._save_summary_to_db(meeting_id, full_content, user_prompt)
|
||||
|
||||
except Exception as e:
|
||||
print(f"流式生成会议总结错误: {e}")
|
||||
yield f"error: {str(e)}"
|
||||
|
||||
def generate_meeting_summary(self, meeting_id: int, user_prompt: str = "") -> Optional[Dict]:
|
||||
"""
|
||||
生成会议总结(非流式,保持向后兼容)
|
||||
|
||||
Args:
|
||||
meeting_id: 会议ID
|
||||
user_prompt: 用户额外提示词
|
||||
task_name: 任务名称,如 'MEETING_TASK', 'KNOWLEDGE_TASK' 等
|
||||
cursor: 数据库游标,如果传入则使用,否则创建新连接
|
||||
|
||||
Returns:
|
||||
包含总结内容的字典,如果失败返回None
|
||||
str: 提示词内容,如果未找到返回默认提示词
|
||||
"""
|
||||
try:
|
||||
# 获取会议转录内容
|
||||
transcript_text = self._get_meeting_transcript(meeting_id)
|
||||
if not transcript_text:
|
||||
return {"error": "无法获取会议转录内容"}
|
||||
|
||||
# 构建完整提示词
|
||||
full_prompt = self._build_prompt(transcript_text, user_prompt)
|
||||
|
||||
# 调用大模型API
|
||||
response = self._call_llm_api(full_prompt)
|
||||
|
||||
if response:
|
||||
# 保存总结到数据库
|
||||
summary_id = self._save_summary_to_db(meeting_id, response, user_prompt)
|
||||
return {
|
||||
"summary_id": summary_id,
|
||||
"content": response,
|
||||
"meeting_id": meeting_id
|
||||
}
|
||||
else:
|
||||
return {"error": "大模型API调用失败"}
|
||||
|
||||
except Exception as e:
|
||||
print(f"生成会议总结错误: {e}")
|
||||
return {"error": str(e)}
|
||||
|
||||
def _get_meeting_transcript(self, meeting_id: int) -> str:
|
||||
"""从数据库获取会议转录内容"""
|
||||
try:
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor()
|
||||
query = """
|
||||
SELECT speaker_tag, start_time_ms, end_time_ms, text_content
|
||||
FROM transcript_segments
|
||||
WHERE meeting_id = %s
|
||||
ORDER BY start_time_ms
|
||||
SELECT p.content
|
||||
FROM prompt_config pc
|
||||
JOIN prompts p ON pc.prompt_id = p.id
|
||||
WHERE pc.task_name = %s
|
||||
"""
|
||||
cursor.execute(query, (meeting_id,))
|
||||
segments = cursor.fetchall()
|
||||
|
||||
if not segments:
|
||||
return ""
|
||||
if cursor:
|
||||
cursor.execute(query, (task_name,))
|
||||
result = cursor.fetchone()
|
||||
if result:
|
||||
return result['content'] if isinstance(result, dict) else result[0]
|
||||
else:
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
cursor.execute(query, (task_name,))
|
||||
result = cursor.fetchone()
|
||||
if result:
|
||||
return result['content']
|
||||
|
||||
# 组装转录文本
|
||||
transcript_lines = []
|
||||
for speaker_tag, start_time, end_time, text in segments:
|
||||
# 将毫秒转换为分:秒格式
|
||||
start_min = start_time // 60000
|
||||
start_sec = (start_time % 60000) // 1000
|
||||
transcript_lines.append(f"[{start_min:02d}:{start_sec:02d}] 说话人{speaker_tag}: {text}")
|
||||
# 返回默认提示词
|
||||
return self._get_default_prompt(task_name)
|
||||
|
||||
return "\n".join(transcript_lines)
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取会议转录内容错误: {e}")
|
||||
return ""
|
||||
|
||||
def _build_prompt(self, transcript_text: str, user_prompt: str) -> str:
|
||||
"""构建完整的提示词"""
|
||||
prompt = f"{self.system_prompt}\n\n"
|
||||
|
||||
if user_prompt:
|
||||
prompt += f"用户额外要求:{user_prompt}\n\n"
|
||||
|
||||
prompt += f"会议转录内容:\n{transcript_text}\n\n请根据以上内容生成会议总结:"
|
||||
|
||||
return prompt
|
||||
def _get_default_prompt(self, task_name: str) -> str:
|
||||
"""获取默认提示词"""
|
||||
default_prompts = {
|
||||
'MEETING_TASK': self.system_prompt, # 使用配置文件中的系统提示词
|
||||
'KNOWLEDGE_TASK': "请根据提供的信息生成知识库文章。",
|
||||
}
|
||||
return default_prompts.get(task_name, "请根据提供的内容进行总结和分析。")
|
||||
|
||||
def _call_llm_api_stream(self, prompt: str) -> Generator[str, None, None]:
|
||||
"""流式调用阿里Qwen3大模型API"""
|
||||
|
|
@ -185,7 +111,7 @@ class LLMService:
|
|||
yield f"error: {error_msg}"
|
||||
|
||||
def _call_llm_api(self, prompt: str) -> Optional[str]:
|
||||
"""调用阿里Qwen3大模型API(非流式,保持向后兼容)"""
|
||||
"""调用阿里Qwen3大模型API(非流式)"""
|
||||
try:
|
||||
response = dashscope.Generation.call(
|
||||
model=self.model_name,
|
||||
|
|
@ -205,95 +131,17 @@ class LLMService:
|
|||
print(f"调用大模型API错误: {e}")
|
||||
return None
|
||||
|
||||
def _save_summary_to_db(self, meeting_id: int, summary_content: str, user_prompt: str) -> Optional[int]:
|
||||
"""保存总结到数据库 - 更新meetings表的summary字段"""
|
||||
try:
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor()
|
||||
|
||||
# 更新meetings表的summary字段
|
||||
update_query = """
|
||||
UPDATE meetings
|
||||
SET summary = %s
|
||||
WHERE meeting_id = %s
|
||||
"""
|
||||
cursor.execute(update_query, (summary_content, meeting_id))
|
||||
connection.commit()
|
||||
|
||||
print(f"成功保存会议总结到meetings表,meeting_id: {meeting_id}")
|
||||
return meeting_id
|
||||
|
||||
except Exception as e:
|
||||
print(f"保存总结到数据库错误: {e}")
|
||||
return None
|
||||
|
||||
def get_meeting_summaries(self, meeting_id: int) -> List[Dict]:
|
||||
"""获取会议的当前总结 - 从meetings表的summary字段获取"""
|
||||
try:
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor()
|
||||
query = """
|
||||
SELECT summary
|
||||
FROM meetings
|
||||
WHERE meeting_id = %s
|
||||
"""
|
||||
cursor.execute(query, (meeting_id,))
|
||||
result = cursor.fetchone()
|
||||
|
||||
# 如果有总结内容,返回一个包含当前总结的列表格式(保持API一致性)
|
||||
if result and result[0]:
|
||||
return [{
|
||||
"id": meeting_id,
|
||||
"content": result[0],
|
||||
"user_prompt": "", # meetings表中没有user_prompt字段
|
||||
"created_at": None # meetings表中没有单独的总结创建时间
|
||||
}]
|
||||
else:
|
||||
return []
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取会议总结错误: {e}")
|
||||
return []
|
||||
|
||||
def get_current_meeting_summary(self, meeting_id: int) -> Optional[str]:
|
||||
"""获取会议当前的总结内容 - 从meetings表的summary字段获取"""
|
||||
try:
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor()
|
||||
query = """
|
||||
SELECT summary
|
||||
FROM meetings
|
||||
WHERE meeting_id = %s
|
||||
"""
|
||||
cursor.execute(query, (meeting_id,))
|
||||
result = cursor.fetchone()
|
||||
|
||||
return result[0] if result and result[0] else None
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取会议当前总结错误: {e}")
|
||||
return None
|
||||
|
||||
|
||||
# 测试代码
|
||||
if __name__ == '__main__':
|
||||
# 测试LLM服务
|
||||
test_meeting_id = 38
|
||||
test_user_prompt = "请重点关注决策事项和待办任务"
|
||||
|
||||
print("--- 运行LLM服务测试 ---")
|
||||
llm_service = LLMService()
|
||||
|
||||
# 生成总结
|
||||
result = llm_service.generate_meeting_summary(test_meeting_id, test_user_prompt)
|
||||
if result.get("error"):
|
||||
print(f"生成总结失败: {result['error']}")
|
||||
else:
|
||||
print(f"总结生成成功,ID: {result.get('summary_id')}")
|
||||
print(f"总结内容: {result.get('content')[:200]}...")
|
||||
# 测试获取任务提示词
|
||||
meeting_prompt = llm_service.get_task_prompt('MEETING_TASK')
|
||||
print(f"会议任务提示词: {meeting_prompt[:100]}...")
|
||||
|
||||
# 获取历史总结
|
||||
summaries = llm_service.get_meeting_summaries(test_meeting_id)
|
||||
print(f"获取到 {len(summaries)} 个历史总结")
|
||||
knowledge_prompt = llm_service.get_task_prompt('KNOWLEDGE_TASK')
|
||||
print(f"知识库任务提示词: {knowledge_prompt[:100]}...")
|
||||
|
||||
print("--- LLM服务测试完成 ---")
|
||||
|
|
@ -0,0 +1,218 @@
|
|||
"""
|
||||
声纹服务 - 处理用户声纹采集、存储和验证
|
||||
"""
|
||||
import os
|
||||
import json
|
||||
import wave
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict
|
||||
from pathlib import Path
|
||||
|
||||
from app.core.database import get_db_connection
|
||||
import app.core.config as config_module
|
||||
|
||||
|
||||
class VoiceprintService:
|
||||
"""声纹服务类 - 同步处理声纹采集"""
|
||||
|
||||
def __init__(self):
|
||||
self.voiceprint_dir = config_module.VOICEPRINT_DIR
|
||||
self.voiceprint_config = config_module.VOICEPRINT_CONFIG
|
||||
|
||||
def get_user_voiceprint_status(self, user_id: int) -> Dict:
|
||||
"""
|
||||
获取用户声纹状态
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
|
||||
Returns:
|
||||
Dict: 声纹状态信息
|
||||
"""
|
||||
try:
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
query = """
|
||||
SELECT vp_id, user_id, file_path, file_size, duration_seconds, collected_at, updated_at
|
||||
FROM user_voiceprint
|
||||
WHERE user_id = %s
|
||||
"""
|
||||
cursor.execute(query, (user_id,))
|
||||
voiceprint = cursor.fetchone()
|
||||
|
||||
if voiceprint:
|
||||
return {
|
||||
"has_voiceprint": True,
|
||||
"vp_id": voiceprint['vp_id'],
|
||||
"file_path": voiceprint['file_path'],
|
||||
"duration_seconds": float(voiceprint['duration_seconds']) if voiceprint['duration_seconds'] else None,
|
||||
"collected_at": voiceprint['collected_at'].isoformat() if voiceprint['collected_at'] else None
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"has_voiceprint": False,
|
||||
"vp_id": None,
|
||||
"file_path": None,
|
||||
"duration_seconds": None,
|
||||
"collected_at": None
|
||||
}
|
||||
except Exception as e:
|
||||
print(f"获取声纹状态错误: {e}")
|
||||
raise e
|
||||
|
||||
def save_voiceprint(self, user_id: int, audio_file_path: str, file_size: int) -> Dict:
|
||||
"""
|
||||
保存声纹文件并提取特征向量
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
audio_file_path: 音频文件路径
|
||||
file_size: 文件大小
|
||||
|
||||
Returns:
|
||||
Dict: 保存结果
|
||||
"""
|
||||
try:
|
||||
# 1. 获取音频时长
|
||||
duration = self._get_audio_duration(audio_file_path)
|
||||
|
||||
# 2. 提取声纹向量(调用FunASR)
|
||||
vector_data = self._extract_voiceprint_vector(audio_file_path)
|
||||
|
||||
# 3. 保存到数据库
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
|
||||
# 检查用户是否已有声纹
|
||||
cursor.execute("SELECT vp_id FROM user_voiceprint WHERE user_id = %s", (user_id,))
|
||||
existing = cursor.fetchone()
|
||||
|
||||
# 计算相对路径
|
||||
relative_path = str(Path(audio_file_path).relative_to(config_module.BASE_DIR))
|
||||
|
||||
if existing:
|
||||
# 更新现有记录
|
||||
update_query = """
|
||||
UPDATE user_voiceprint
|
||||
SET file_path = %s, file_size = %s, duration_seconds = %s,
|
||||
vector_data = %s, updated_at = NOW()
|
||||
WHERE user_id = %s
|
||||
"""
|
||||
cursor.execute(update_query, (
|
||||
relative_path, file_size, duration,
|
||||
json.dumps(vector_data) if vector_data else None,
|
||||
user_id
|
||||
))
|
||||
vp_id = existing['vp_id']
|
||||
else:
|
||||
# 插入新记录
|
||||
insert_query = """
|
||||
INSERT INTO user_voiceprint
|
||||
(user_id, file_path, file_size, duration_seconds, vector_data, collected_at, updated_at)
|
||||
VALUES (%s, %s, %s, %s, %s, NOW(), NOW())
|
||||
"""
|
||||
cursor.execute(insert_query, (
|
||||
user_id, relative_path, file_size, duration,
|
||||
json.dumps(vector_data) if vector_data else None
|
||||
))
|
||||
vp_id = cursor.lastrowid
|
||||
|
||||
connection.commit()
|
||||
|
||||
return {
|
||||
"vp_id": vp_id,
|
||||
"user_id": user_id,
|
||||
"file_path": relative_path,
|
||||
"file_size": file_size,
|
||||
"duration_seconds": duration,
|
||||
"has_vector": vector_data is not None
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
print(f"保存声纹错误: {e}")
|
||||
raise e
|
||||
|
||||
def delete_voiceprint(self, user_id: int) -> bool:
|
||||
"""
|
||||
删除用户声纹
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
|
||||
Returns:
|
||||
bool: 是否删除成功
|
||||
"""
|
||||
try:
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
|
||||
# 获取文件路径
|
||||
cursor.execute("SELECT file_path FROM user_voiceprint WHERE user_id = %s", (user_id,))
|
||||
voiceprint = cursor.fetchone()
|
||||
|
||||
if voiceprint:
|
||||
# 构建完整文件路径
|
||||
relative_path = voiceprint['file_path']
|
||||
if relative_path.startswith('/'):
|
||||
relative_path = relative_path.lstrip('/')
|
||||
file_path = config_module.BASE_DIR / relative_path
|
||||
|
||||
# 删除数据库记录
|
||||
cursor.execute("DELETE FROM user_voiceprint WHERE user_id = %s", (user_id,))
|
||||
connection.commit()
|
||||
|
||||
# 删除文件
|
||||
if file_path.exists():
|
||||
os.remove(file_path)
|
||||
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"删除声纹错误: {e}")
|
||||
raise e
|
||||
|
||||
def _get_audio_duration(self, audio_file_path: str) -> float:
|
||||
"""
|
||||
获取音频文件时长
|
||||
|
||||
Args:
|
||||
audio_file_path: 音频文件路径
|
||||
|
||||
Returns:
|
||||
float: 时长(秒)
|
||||
"""
|
||||
try:
|
||||
with wave.open(audio_file_path, 'rb') as wav_file:
|
||||
frames = wav_file.getnframes()
|
||||
rate = wav_file.getframerate()
|
||||
duration = frames / float(rate)
|
||||
return round(duration, 2)
|
||||
except Exception as e:
|
||||
print(f"获取音频时长错误: {e}")
|
||||
return 10.0 # 默认返回10秒
|
||||
|
||||
def _extract_voiceprint_vector(self, audio_file_path: str) -> Optional[list]:
|
||||
"""
|
||||
提取声纹特征向量(调用FunASR)
|
||||
|
||||
Args:
|
||||
audio_file_path: 音频文件路径
|
||||
|
||||
Returns:
|
||||
Optional[list]: 声纹向量(192维),失败返回None
|
||||
"""
|
||||
# TODO: 集成FunASR的说话人识别模型
|
||||
# 使用 speech_campplus_sv_zh-cn_16k-common 模型
|
||||
# 返回192维的embedding向量
|
||||
|
||||
print(f"[TODO] 调用FunASR提取声纹向量: {audio_file_path}")
|
||||
|
||||
# 暂时返回None,等待FunASR集成
|
||||
# 集成后应该返回类似: [0.123, -0.456, 0.789, ...]
|
||||
return None
|
||||
|
||||
|
||||
# 创建全局实例
|
||||
voiceprint_service = VoiceprintService()
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
{
|
||||
"model_name": "qwen-plus",
|
||||
"system_prompt": "你是一个专业的会议记录分析助手。请根据提供的会议转录内容,生成简洁明了的会议总结。\n\n总结包括五个部分(名称严格一致,生成为MD二级目录):\n1. 会议概述 - 简要说明会议的主要目的和背景(生成MD引用)\n2. 主要讨论点 - 列出会议中讨论的重要话题和内容\n3. 决策事项 - 明确记录会议中做出的决定和结论\n4. 待办事项 - 列出需要后续跟进的任务和责任人\n5. 关键信息 - 其他重要的信息点\n\n输出要求:\n- 保持客观中性,不添加个人观点\n- 使用简洁、准确的中文表达\n- 按重要性排序各项内容\n- 如果某个部分没有相关内容,可以说明\"无相关内容\"\n- 总字数控制在500字以内",
|
||||
"DEFAULT_RESET_PASSWORD": "111111",
|
||||
"MAX_FILE_SIZE": 209715200,
|
||||
"DEFAULT_RESET_PASSWORD": "123456",
|
||||
"MAX_FILE_SIZE": 208666624,
|
||||
"MAX_IMAGE_SIZE": 10485760
|
||||
}
|
||||
3
main.py
3
main.py
|
|
@ -2,7 +2,7 @@ import uvicorn
|
|||
from fastapi import FastAPI, Request, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from app.api.endpoints import auth, users, meetings, tags, admin, tasks, prompts, knowledge_base, client_downloads
|
||||
from app.api.endpoints import auth, users, meetings, tags, admin, tasks, prompts, knowledge_base, client_downloads, voiceprint
|
||||
from app.core.config import UPLOAD_DIR, API_CONFIG
|
||||
from app.api.endpoints.admin import load_system_config
|
||||
import os
|
||||
|
|
@ -39,6 +39,7 @@ app.include_router(tasks.router, prefix="/api", tags=["Tasks"])
|
|||
app.include_router(prompts.router, prefix="/api", tags=["Prompts"])
|
||||
app.include_router(knowledge_base.router, prefix="/api", tags=["KnowledgeBase"])
|
||||
app.include_router(client_downloads.router, prefix="/api/clients", tags=["ClientDownloads"])
|
||||
app.include_router(voiceprint.router, prefix="/api", tags=["Voiceprint"])
|
||||
|
||||
@app.get("/")
|
||||
def read_root():
|
||||
|
|
|
|||
|
|
@ -6,10 +6,10 @@ import sys
|
|||
import os
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from app.services.async_llm_service import AsyncLLMService
|
||||
from app.services.async_meeting_service import AsyncMeetingService
|
||||
|
||||
# 创建服务实例
|
||||
service = AsyncLLMService()
|
||||
service = AsyncMeetingService()
|
||||
|
||||
# 创建测试任务
|
||||
meeting_id = 38
|
||||
|
|
|
|||
|
|
@ -8,10 +8,10 @@ import time
|
|||
import threading
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from app.services.async_llm_service import AsyncLLMService
|
||||
from app.services.async_meeting_service import AsyncMeetingService
|
||||
|
||||
# 创建服务实例
|
||||
service = AsyncLLMService()
|
||||
service = AsyncMeetingService()
|
||||
|
||||
# 直接调用处理任务方法测试
|
||||
print("测试直接调用_process_tasks方法...")
|
||||
|
|
|
|||
|
|
@ -0,0 +1,141 @@
|
|||
"""
|
||||
声纹采集API测试脚本
|
||||
|
||||
使用方法:
|
||||
1. 确保后端服务正在运行
|
||||
2. 修改 USER_ID 和 TOKEN 为实际值
|
||||
3. 准备一个10秒的WAV音频文件
|
||||
4. 运行: python test_voiceprint_api.py
|
||||
"""
|
||||
|
||||
import requests
|
||||
import json
|
||||
|
||||
# 配置
|
||||
BASE_URL = "http://localhost:8000/api"
|
||||
USER_ID = 1 # 修改为实际用户ID
|
||||
TOKEN = "" # 登录后获取的token
|
||||
|
||||
# 请求头
|
||||
headers = {
|
||||
"Authorization": f"Bearer {TOKEN}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
|
||||
def test_get_template():
|
||||
"""测试获取朗读模板"""
|
||||
print("\n=== 测试1: 获取朗读模板 ===")
|
||||
url = f"{BASE_URL}/voiceprint/template"
|
||||
response = requests.get(url, headers=headers)
|
||||
print(f"状态码: {response.status_code}")
|
||||
print(f"响应: {json.dumps(response.json(), ensure_ascii=False, indent=2)}")
|
||||
return response.json()
|
||||
|
||||
|
||||
def test_get_status(user_id):
|
||||
"""测试获取声纹状态"""
|
||||
print(f"\n=== 测试2: 获取用户 {user_id} 的声纹状态 ===")
|
||||
url = f"{BASE_URL}/voiceprint/{user_id}"
|
||||
response = requests.get(url, headers=headers)
|
||||
print(f"状态码: {response.status_code}")
|
||||
print(f"响应: {json.dumps(response.json(), ensure_ascii=False, indent=2)}")
|
||||
return response.json()
|
||||
|
||||
|
||||
def test_upload_voiceprint(user_id, audio_file_path):
|
||||
"""测试上传声纹"""
|
||||
print(f"\n=== 测试3: 上传声纹音频 ===")
|
||||
url = f"{BASE_URL}/voiceprint/{user_id}"
|
||||
|
||||
# 移除Content-Type,让requests自动设置multipart/form-data
|
||||
upload_headers = {
|
||||
"Authorization": f"Bearer {TOKEN}"
|
||||
}
|
||||
|
||||
with open(audio_file_path, 'rb') as f:
|
||||
files = {'audio_file': (audio_file_path.split('/')[-1], f, 'audio/wav')}
|
||||
response = requests.post(url, headers=upload_headers, files=files)
|
||||
|
||||
print(f"状态码: {response.status_code}")
|
||||
print(f"响应: {json.dumps(response.json(), ensure_ascii=False, indent=2)}")
|
||||
return response.json()
|
||||
|
||||
|
||||
def test_delete_voiceprint(user_id):
|
||||
"""测试删除声纹"""
|
||||
print(f"\n=== 测试4: 删除用户 {user_id} 的声纹 ===")
|
||||
url = f"{BASE_URL}/voiceprint/{user_id}"
|
||||
response = requests.delete(url, headers=headers)
|
||||
print(f"状态码: {response.status_code}")
|
||||
print(f"响应: {json.dumps(response.json(), ensure_ascii=False, indent=2)}")
|
||||
return response.json()
|
||||
|
||||
|
||||
def login(username, password):
|
||||
"""登录获取token"""
|
||||
print("\n=== 登录获取Token ===")
|
||||
url = f"{BASE_URL}/auth/login"
|
||||
data = {
|
||||
"username": username,
|
||||
"password": password
|
||||
}
|
||||
response = requests.post(url, json=data)
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
if result.get('code') == '200':
|
||||
token = result['data']['token']
|
||||
print(f"登录成功,Token: {token[:20]}...")
|
||||
return token
|
||||
else:
|
||||
print(f"登录失败: {result.get('message')}")
|
||||
return None
|
||||
else:
|
||||
print(f"请求失败,状态码: {response.status_code}")
|
||||
return None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("=" * 60)
|
||||
print("声纹采集API测试脚本")
|
||||
print("=" * 60)
|
||||
|
||||
# 步骤1: 登录(如果没有token)
|
||||
if not TOKEN:
|
||||
print("\n请先登录获取Token...")
|
||||
username = input("用户名: ")
|
||||
password = input("密码: ")
|
||||
TOKEN = login(username, password)
|
||||
if TOKEN:
|
||||
headers["Authorization"] = f"Bearer {TOKEN}"
|
||||
else:
|
||||
print("登录失败,退出测试")
|
||||
exit(1)
|
||||
|
||||
# 步骤2: 测试获取朗读模板
|
||||
test_get_template()
|
||||
|
||||
# 步骤3: 测试获取声纹状态
|
||||
test_get_status(USER_ID)
|
||||
|
||||
# 步骤4: 测试上传声纹(需要准备音频文件)
|
||||
audio_file = input("\n请输入WAV音频文件路径 (回车跳过上传测试): ")
|
||||
if audio_file.strip():
|
||||
test_upload_voiceprint(USER_ID, audio_file.strip())
|
||||
|
||||
# 上传后再次查看状态
|
||||
print("\n=== 上传后再次查看状态 ===")
|
||||
test_get_status(USER_ID)
|
||||
|
||||
# 步骤5: 测试删除声纹
|
||||
confirm = input("\n是否测试删除声纹? (yes/no): ")
|
||||
if confirm.lower() == 'yes':
|
||||
test_delete_voiceprint(USER_ID)
|
||||
|
||||
# 删除后再次查看状态
|
||||
print("\n=== 删除后再次查看状态 ===")
|
||||
test_get_status(USER_ID)
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("测试完成")
|
||||
print("=" * 60)
|
||||
Loading…
Reference in New Issue