Compare commits

...

2 Commits

Author SHA1 Message Date
mula.liu 2f36474f4d 1.0.3 2025-10-31 14:54:54 +08:00
mula.liu 976ea854b6 整理了会议和知识库的代码结构 2025-10-28 19:30:09 +08:00
19 changed files with 913 additions and 356 deletions

BIN
.DS_Store vendored

Binary file not shown.

BIN
app.zip

Binary file not shown.

View File

@ -9,29 +9,25 @@ import datetime
router = APIRouter()
def _process_tags(cursor, tag_string: Optional[str]) -> List[Tag]:
def _process_tags(cursor, tag_string: Optional[str], creator_id: Optional[int] = None) -> List[Tag]:
"""
处理标签查询已存在的标签如果提供了 creator_id 则创建不存在的标签
"""
if not tag_string:
return []
tag_names = [name.strip() for name in tag_string.split(',') if name.strip()]
if not tag_names:
return []
placeholders = ','.join(['%s'] * len(tag_names))
select_query = f"SELECT id, name, color FROM tags WHERE name IN ({placeholders})"
cursor.execute(select_query, tuple(tag_names))
# 如果提供了 creator_id则创建不存在的标签
if creator_id:
insert_ignore_query = "INSERT IGNORE INTO tags (name, creator_id) VALUES (%s, %s)"
cursor.executemany(insert_ignore_query, [(name, creator_id) for name in tag_names])
# 查询所有标签信息
format_strings = ', '.join(['%s'] * len(tag_names))
cursor.execute(f"SELECT id, name, color FROM tags WHERE name IN ({format_strings})", tuple(tag_names))
tags_data = cursor.fetchall()
existing_tags = {tag['name']: tag for tag in tags_data}
new_tags = [name for name in tag_names if name not in existing_tags]
if new_tags:
insert_query = "INSERT INTO tags (name) VALUES (%s)"
cursor.executemany(insert_query, [(name,) for name in new_tags])
# Re-fetch all tags to get their IDs and default colors
cursor.execute(select_query, tuple(tag_names))
tags_data = cursor.fetchall()
return [Tag(**tag) for tag in tags_data]
@ -83,7 +79,8 @@ def get_knowledge_bases(
kb_list = []
for kb_data in kbs_data:
kb_data['tags'] = _process_tags(cursor, kb_data.get('tags'))
# 列表页不需要处理 tags直接使用字符串
# kb_data['tags'] 保持原样(逗号分隔的标签名称字符串)
# Count source meetings - filter empty strings
if kb_data.get('source_meeting_ids'):
meeting_ids = [mid.strip() for mid in kb_data['source_meeting_ids'].split(',') if mid.strip()]
@ -122,7 +119,7 @@ def create_knowledge_base(
request.is_shared,
request.source_meeting_ids,
request.user_prompt,
request.tags,
request.tags, # 创建时 tags 应该为 None 或空字符串
now,
now
))
@ -171,7 +168,8 @@ def get_knowledge_base_detail(
if not kb_data['is_shared'] and kb_data['creator_id'] != current_user['user_id']:
raise HTTPException(status_code=403, detail="Access denied")
# Process tags
# Process tags - 获取标签的完整信息(包括颜色)
# 详情页不需要创建新标签,所以不传 creator_id
kb_data['tags'] = _process_tags(cursor, kb_data.get('tags'))
# Get source meetings details
@ -220,6 +218,10 @@ def update_knowledge_base(
if kb['creator_id'] != current_user['user_id']:
raise HTTPException(status_code=403, detail="Only the creator can update this knowledge base")
# 使用 _process_tags 处理标签(会自动创建新标签)
if request.tags:
_process_tags(cursor, request.tags, current_user['user_id'])
# Update the knowledge base
now = datetime.datetime.utcnow()
update_query = """

View File

@ -5,7 +5,7 @@ from app.core.config import BASE_DIR, AUDIO_DIR, MARKDOWN_DIR, ALLOWED_EXTENSION
import app.core.config as config_module
from app.services.llm_service import LLMService
from app.services.async_transcription_service import AsyncTranscriptionService
from app.services.async_llm_service import async_llm_service
from app.services.async_meeting_service import async_meeting_service
from app.core.auth import get_current_user
from app.core.response import create_api_response
from typing import List, Optional
@ -23,14 +23,22 @@ transcription_service = AsyncTranscriptionService()
class GenerateSummaryRequest(BaseModel):
user_prompt: Optional[str] = ""
def _process_tags(cursor, tag_string: Optional[str]) -> List[Tag]:
def _process_tags(cursor, tag_string: Optional[str], creator_id: Optional[int] = None) -> List[Tag]:
"""
处理标签查询已存在的标签如果提供了 creator_id 则创建不存在的标签
"""
if not tag_string:
return []
tag_names = [name.strip() for name in tag_string.split(',') if name.strip()]
if not tag_names:
return []
insert_ignore_query = "INSERT IGNORE INTO tags (name) VALUES (%s)"
cursor.executemany(insert_ignore_query, [(name,) for name in tag_names])
# 如果提供了 creator_id则创建不存在的标签
if creator_id:
insert_ignore_query = "INSERT IGNORE INTO tags (name, creator_id) VALUES (%s, %s)"
cursor.executemany(insert_ignore_query, [(name, creator_id) for name in tag_names])
# 查询所有标签信息
format_strings = ', '.join(['%s'] * len(tag_names))
cursor.execute(f"SELECT id, name, color FROM tags WHERE name IN ({format_strings})", tuple(tag_names))
tags_data = cursor.fetchall()
@ -125,10 +133,9 @@ def get_meeting_transcript(meeting_id: int, current_user: dict = Depends(get_cur
def create_meeting(meeting_request: CreateMeetingRequest, current_user: dict = Depends(get_current_user)):
with get_db_connection() as connection:
cursor = connection.cursor(dictionary=True)
# 使用 _process_tags 来处理标签创建
if meeting_request.tags:
tag_names = [name.strip() for name in meeting_request.tags.split(',') if name.strip()]
if tag_names:
cursor.executemany("INSERT IGNORE INTO tags (name) VALUES (%s)", [(name,) for name in tag_names])
_process_tags(cursor, meeting_request.tags, current_user['user_id'])
meeting_query = 'INSERT INTO meetings (user_id, title, meeting_time, summary, tags, created_at) VALUES (%s, %s, %s, %s, %s, %s)'
cursor.execute(meeting_query, (meeting_request.user_id, meeting_request.title, meeting_request.meeting_time, None, meeting_request.tags, datetime.now().isoformat()))
meeting_id = cursor.lastrowid
@ -147,10 +154,9 @@ def update_meeting(meeting_id: int, meeting_request: UpdateMeetingRequest, curre
return create_api_response(code="404", message="Meeting not found")
if meeting['user_id'] != current_user['user_id']:
return create_api_response(code="403", message="Permission denied")
# 使用 _process_tags 来处理标签创建
if meeting_request.tags:
tag_names = [name.strip() for name in meeting_request.tags.split(',') if name.strip()]
if tag_names:
cursor.executemany("INSERT IGNORE INTO tags (name) VALUES (%s)", [(name,) for name in tag_names])
_process_tags(cursor, meeting_request.tags, current_user['user_id'])
update_query = 'UPDATE meetings SET title = %s, meeting_time = %s, summary = %s, tags = %s WHERE meeting_id = %s'
cursor.execute(update_query, (meeting_request.title, meeting_request.meeting_time, meeting_request.summary, meeting_request.tags, meeting_id))
cursor.execute("DELETE FROM attendees WHERE meeting_id = %s", (meeting_id,))
@ -449,8 +455,8 @@ def generate_meeting_summary_async(meeting_id: int, request: GenerateSummaryRequ
cursor.execute("SELECT meeting_id FROM meetings WHERE meeting_id = %s", (meeting_id,))
if not cursor.fetchone():
return create_api_response(code="404", message="Meeting not found")
task_id = async_llm_service.start_summary_generation(meeting_id, request.user_prompt)
background_tasks.add_task(async_llm_service._process_task, task_id)
task_id = async_meeting_service.start_summary_generation(meeting_id, request.user_prompt)
background_tasks.add_task(async_meeting_service._process_task, task_id)
return create_api_response(code="200", message="Summary generation task has been accepted.", data={
"task_id": task_id, "status": "pending", "meeting_id": meeting_id
})
@ -465,7 +471,7 @@ def get_meeting_llm_tasks(meeting_id: int, current_user: dict = Depends(get_curr
cursor.execute("SELECT meeting_id FROM meetings WHERE meeting_id = %s", (meeting_id,))
if not cursor.fetchone():
return create_api_response(code="404", message="Meeting not found")
tasks = async_llm_service.get_meeting_llm_tasks(meeting_id)
tasks = async_meeting_service.get_meeting_llm_tasks(meeting_id)
return create_api_response(code="200", message="LLM tasks retrieved successfully", data={
"tasks": tasks, "total": len(tasks)
})

View File

@ -2,7 +2,7 @@ from fastapi import APIRouter, Depends
from pydantic import BaseModel
from typing import List, Optional
from app.core.auth import get_current_admin_user
from app.core.auth import get_current_user
from app.core.database import get_db_connection
from app.core.response import create_api_response
@ -23,14 +23,14 @@ class PromptListResponse(BaseModel):
total: int
@router.post("/prompts")
def create_prompt(prompt: PromptIn, current_user: dict = Depends(get_current_admin_user)):
def create_prompt(prompt: PromptIn, current_user: dict = Depends(get_current_user)):
"""Create a new prompt."""
with get_db_connection() as connection:
cursor = connection.cursor(dictionary=True)
try:
cursor.execute(
"INSERT INTO prompts (name, tags, content) VALUES (%s, %s, %s)",
(prompt.name, prompt.tags, prompt.content)
"INSERT INTO prompts (name, tags, content, creator_id) VALUES (%s, %s, %s, %s)",
(prompt.name, prompt.tags, prompt.content, current_user["user_id"])
)
connection.commit()
new_id = cursor.lastrowid
@ -41,23 +41,27 @@ def create_prompt(prompt: PromptIn, current_user: dict = Depends(get_current_adm
return create_api_response(code="500", message=f"创建提示词失败: {e}")
@router.get("/prompts")
def get_prompts(page: int = 1, size: int = 12, current_user: dict = Depends(get_current_admin_user)):
"""Get a paginated list of prompts."""
def get_prompts(page: int = 1, size: int = 12, current_user: dict = Depends(get_current_user)):
"""Get a paginated list of prompts filtered by current user."""
with get_db_connection() as connection:
cursor = connection.cursor(dictionary=True)
cursor.execute("SELECT COUNT(*) as total FROM prompts")
# 只获取当前用户创建的提示词
cursor.execute(
"SELECT COUNT(*) as total FROM prompts WHERE creator_id = %s",
(current_user["user_id"],)
)
total = cursor.fetchone()['total']
offset = (page - 1) * size
cursor.execute(
"SELECT id, name, tags, content, created_at FROM prompts ORDER BY created_at DESC LIMIT %s OFFSET %s",
(size, offset)
"SELECT id, name, tags, content, created_at FROM prompts WHERE creator_id = %s ORDER BY created_at DESC LIMIT %s OFFSET %s",
(current_user["user_id"], size, offset)
)
prompts = cursor.fetchall()
return create_api_response(code="200", message="获取提示词列表成功", data={"prompts": prompts, "total": total})
@router.get("/prompts/{prompt_id}")
def get_prompt(prompt_id: int, current_user: dict = Depends(get_current_admin_user)):
def get_prompt(prompt_id: int, current_user: dict = Depends(get_current_user)):
"""Get a single prompt by its ID."""
with get_db_connection() as connection:
cursor = connection.cursor(dictionary=True)
@ -68,7 +72,7 @@ def get_prompt(prompt_id: int, current_user: dict = Depends(get_current_admin_us
return create_api_response(code="200", message="获取提示词成功", data=prompt)
@router.put("/prompts/{prompt_id}")
def update_prompt(prompt_id: int, prompt: PromptIn, current_user: dict = Depends(get_current_admin_user)):
def update_prompt(prompt_id: int, prompt: PromptIn, current_user: dict = Depends(get_current_user)):
"""Update an existing prompt."""
with get_db_connection() as connection:
cursor = connection.cursor(dictionary=True)
@ -87,12 +91,24 @@ def update_prompt(prompt_id: int, prompt: PromptIn, current_user: dict = Depends
return create_api_response(code="500", message=f"更新提示词失败: {e}")
@router.delete("/prompts/{prompt_id}")
def delete_prompt(prompt_id: int, current_user: dict = Depends(get_current_admin_user)):
"""Delete a prompt."""
def delete_prompt(prompt_id: int, current_user: dict = Depends(get_current_user)):
"""Delete a prompt. Only the creator can delete their own prompts."""
with get_db_connection() as connection:
cursor = connection.cursor()
cursor.execute("DELETE FROM prompts WHERE id = %s", (prompt_id,))
if cursor.rowcount == 0:
cursor = connection.cursor(dictionary=True)
# 首先检查提示词是否存在以及是否属于当前用户
cursor.execute(
"SELECT creator_id FROM prompts WHERE id = %s",
(prompt_id,)
)
prompt = cursor.fetchone()
if not prompt:
return create_api_response(code="404", message="提示词不存在")
if prompt['creator_id'] != current_user["user_id"]:
return create_api_response(code="403", message="无权删除其他用户的提示词")
# 删除提示词
cursor.execute("DELETE FROM prompts WHERE id = %s", (prompt_id,))
connection.commit()
return create_api_response(code="200", message="提示词删除成功")

View File

@ -1,6 +1,7 @@
from fastapi import APIRouter, Depends
from app.core.database import get_db_connection
from app.core.response import create_api_response
from app.core.auth import get_current_user
from app.models.models import Tag
from typing import List
import mysql.connector
@ -24,16 +25,16 @@ def get_all_tags():
return create_api_response(code="500", message="获取标签失败")
@router.post("/tags/")
def create_tag(tag_in: Tag):
def create_tag(tag_in: Tag, current_user: dict = Depends(get_current_user)):
"""_summary_
创建一个新标签
创建一个新标签并记录创建者
"""
query = "INSERT INTO tags (name, color) VALUES (%s, %s)"
query = "INSERT INTO tags (name, color, creator_id) VALUES (%s, %s, %s)"
try:
with get_db_connection() as connection:
with connection.cursor(dictionary=True) as cursor:
try:
cursor.execute(query, (tag_in.name, tag_in.color))
cursor.execute(query, (tag_in.name, tag_in.color, current_user["user_id"]))
connection.commit()
tag_id = cursor.lastrowid
new_tag = {"id": tag_id, "name": tag_in.name, "color": tag_in.color}

View File

@ -2,7 +2,7 @@ from fastapi import APIRouter, Depends
from app.core.auth import get_current_user
from app.core.response import create_api_response
from app.services.async_transcription_service import AsyncTranscriptionService
from app.services.async_llm_service import async_llm_service
from app.services.async_meeting_service import async_meeting_service
router = APIRouter()
@ -23,7 +23,7 @@ def get_transcription_task_status(task_id: str, current_user: dict = Depends(get
def get_llm_task_status(task_id: str, current_user: dict = Depends(get_current_user)):
"""获取LLM总结任务状态包括进度"""
try:
status = async_llm_service.get_task_status(task_id)
status = async_meeting_service.get_task_status(task_id)
if status.get('status') == 'not_found':
return create_api_response(code="404", message="Task not found")
return create_api_response(code="200", message="Task status retrieved", data=status)

View File

@ -0,0 +1,131 @@
"""
声纹采集API接口
"""
from fastapi import APIRouter, Depends, UploadFile, File, HTTPException
from typing import Optional
from pathlib import Path
import datetime
from app.models.models import VoiceprintStatus, VoiceprintTemplate
from app.core.auth import get_current_user
from app.core.response import create_api_response
from app.services.voiceprint_service import voiceprint_service
import app.core.config as config_module
router = APIRouter()
@router.get("/voiceprint/template", response_model=None)
def get_voiceprint_template(current_user: dict = Depends(get_current_user)):
"""
获取声纹采集朗读模板配置
权限需要登录
"""
try:
template_data = VoiceprintTemplate(**config_module.VOICEPRINT_CONFIG)
return create_api_response(code="200", message="获取朗读模板成功", data=template_data.dict())
except Exception as e:
return create_api_response(code="500", message=f"获取朗读模板失败: {str(e)}")
@router.get("/voiceprint/{user_id}", response_model=None)
def get_voiceprint_status(user_id: int, current_user: dict = Depends(get_current_user)):
"""
获取用户声纹采集状态
权限用户只能查询自己的声纹状态管理员可查询所有
"""
# 权限检查:只能查询自己的声纹,或者是管理员
if current_user['user_id'] != user_id and current_user['role_id'] != 1:
return create_api_response(code="403", message="无权限查询其他用户的声纹状态")
try:
status_data = voiceprint_service.get_user_voiceprint_status(user_id)
return create_api_response(code="200", message="获取声纹状态成功", data=status_data)
except Exception as e:
return create_api_response(code="500", message=f"获取声纹状态失败: {str(e)}")
@router.post("/voiceprint/{user_id}", response_model=None)
async def upload_voiceprint(
user_id: int,
audio_file: UploadFile = File(...),
current_user: dict = Depends(get_current_user)
):
"""
上传声纹音频文件同步处理
权限用户只能上传自己的声纹管理员可操作所有
"""
# 权限检查
if current_user['user_id'] != user_id and current_user['role_id'] != 1:
return create_api_response(code="403", message="无权限上传其他用户的声纹")
# 检查文件格式
file_ext = Path(audio_file.filename).suffix.lower()
if file_ext not in config_module.ALLOWED_VOICEPRINT_EXTENSIONS:
return create_api_response(
code="400",
message=f"不支持的文件格式,仅支持: {', '.join(config_module.ALLOWED_VOICEPRINT_EXTENSIONS)}"
)
# 检查文件大小
max_size = config_module.VOICEPRINT_CONFIG.get('max_file_size', 5242880) # 默认5MB
content = await audio_file.read()
file_size = len(content)
if file_size > max_size:
return create_api_response(
code="400",
message=f"文件过大,最大允许 {max_size / 1024 / 1024:.1f}MB"
)
try:
# 确保用户目录存在
user_voiceprint_dir = config_module.VOICEPRINT_DIR / str(user_id)
user_voiceprint_dir.mkdir(parents=True, exist_ok=True)
# 生成文件名:时间戳.wav
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
filename = f"{timestamp}.wav"
file_path = user_voiceprint_dir / filename
# 保存文件
with open(file_path, "wb") as f:
f.write(content)
# 调用服务处理声纹(提取特征向量,保存到数据库)
result = voiceprint_service.save_voiceprint(user_id, str(file_path), file_size)
return create_api_response(code="200", message="声纹采集成功", data=result)
except Exception as e:
# 如果出错,删除已上传的文件
if 'file_path' in locals() and Path(file_path).exists():
Path(file_path).unlink()
return create_api_response(code="500", message=f"声纹采集失败: {str(e)}")
@router.delete("/voiceprint/{user_id}", response_model=None)
def delete_voiceprint(user_id: int, current_user: dict = Depends(get_current_user)):
"""
删除用户声纹数据允许重新采集
权限用户只能删除自己的声纹管理员可操作所有
"""
# 权限检查
if current_user['user_id'] != user_id and current_user['role_id'] != 1:
return create_api_response(code="403", message="无权限删除其他用户的声纹")
try:
success = voiceprint_service.delete_voiceprint(user_id)
if success:
return create_api_response(code="200", message="声纹删除成功")
else:
return create_api_response(code="404", message="未找到该用户的声纹数据")
except Exception as e:
return create_api_response(code="500", message=f"删除声纹失败: {str(e)}")

View File

@ -1,4 +1,5 @@
import os
import json
from pathlib import Path
# 基础路径配置
@ -6,10 +7,12 @@ BASE_DIR = Path(__file__).parent.parent.parent
UPLOAD_DIR = BASE_DIR / "uploads"
AUDIO_DIR = UPLOAD_DIR / "audio"
MARKDOWN_DIR = UPLOAD_DIR / "markdown"
VOICEPRINT_DIR = UPLOAD_DIR / "voiceprint"
# 文件上传配置
ALLOWED_EXTENSIONS = {".mp3", ".wav", ".m4a", ".mpeg", ".mp4"}
ALLOWED_IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".gif", ".webp"}
ALLOWED_VOICEPRINT_EXTENSIONS = {".wav"}
MAX_FILE_SIZE = 100 * 1024 * 1024 # 100MB
MAX_IMAGE_SIZE = 10 * 1024 * 1024 # 10MB
@ -17,6 +20,7 @@ MAX_IMAGE_SIZE = 10 * 1024 * 1024 # 10MB
UPLOAD_DIR.mkdir(exist_ok=True)
AUDIO_DIR.mkdir(exist_ok=True)
MARKDOWN_DIR.mkdir(exist_ok=True)
VOICEPRINT_DIR.mkdir(exist_ok=True)
# 数据库配置
DATABASE_CONFIG = {
@ -82,3 +86,12 @@ LLM_CONFIG = {
# 密码重置配置
DEFAULT_RESET_PASSWORD = os.getenv('DEFAULT_RESET_PASSWORD', '111111')
# 加载系统配置文件
# 默认声纹配置
VOICEPRINT_CONFIG = {
"template_text": "我正在进行声纹采集,这段语音将用于身份识别和验证。\n\n声纹技术能够准确识别每个人独特的声音特征。",
"duration_seconds": 12,
"sample_rate": 16000,
"channels": 1
}

View File

@ -128,7 +128,7 @@ class KnowledgeBase(BaseModel):
is_shared: bool
source_meeting_ids: Optional[str] = None
user_prompt: Optional[str] = None
tags: Optional[List[Tag]] = []
tags: Union[Optional[str], Optional[List[Tag]]] = None # 支持字符串或Tag列表
created_at: datetime.datetime
updated_at: datetime.datetime
source_meeting_count: Optional[int] = 0
@ -204,3 +204,26 @@ class UpdateClientDownloadRequest(BaseModel):
class ClientDownloadListResponse(BaseModel):
clients: List[ClientDownload]
total: int
# 声纹采集相关模型
class VoiceprintInfo(BaseModel):
vp_id: int
user_id: int
file_path: str
file_size: Optional[int] = None
duration_seconds: Optional[float] = None
collected_at: datetime.datetime
updated_at: datetime.datetime
class VoiceprintStatus(BaseModel):
has_voiceprint: bool
vp_id: Optional[int] = None
file_path: Optional[str] = None
duration_seconds: Optional[float] = None
collected_at: Optional[datetime.datetime] = None
class VoiceprintTemplate(BaseModel):
template_text: str
duration_seconds: int
sample_rate: int
channels: int

View File

@ -1,3 +1,7 @@
"""
异步知识库服务 - 处理知识库生成的异步任务
采用FastAPI BackgroundTasks模式
"""
import uuid
from datetime import datetime
from typing import Optional, Dict, Any, List
@ -7,6 +11,7 @@ from app.core.database import get_db_connection
from app.services.llm_service import LLMService
class AsyncKnowledgeBaseService:
"""异步知识库服务类 - 处理知识库相关的异步任务"""
def __init__(self):
from app.core.config import REDIS_CONFIG
@ -16,6 +21,19 @@ class AsyncKnowledgeBaseService:
self.llm_service = LLMService()
def start_generation(self, user_id: int, kb_id: int, user_prompt: Optional[str], source_meeting_ids: Optional[str], cursor=None) -> str:
"""
创建异步知识库生成任务
Args:
user_id: 用户ID
kb_id: 知识库ID
user_prompt: 用户提示词
source_meeting_ids: 源会议ID列表
cursor: 数据库游标可选
Returns:
str: 任务ID
"""
task_id = str(uuid.uuid4())
# If a cursor is passed, use it directly to avoid creating a new transaction
@ -47,8 +65,12 @@ class AsyncKnowledgeBaseService:
return task_id
def _process_task(self, task_id: str):
"""
处理单个异步任务的函数设计为由BackgroundTasks调用
"""
print(f"Background task started for knowledge base task: {task_id}")
try:
# 从Redis获取任务数据
task_data = self.redis_client.hgetall(f"kb_task:{task_id}")
if not task_data:
print(f"Error: Task {task_id} not found in Redis for processing.")
@ -57,99 +79,136 @@ class AsyncKnowledgeBaseService:
kb_id = int(task_data['kb_id'])
user_prompt = task_data.get('user_prompt', '')
# 1. 更新状态为processing
self._update_task_status_in_redis(task_id, 'processing', 10, message="任务已开始...")
# 2. 获取关联的会议总结
self._update_task_status_in_redis(task_id, 'processing', 20, message="获取关联会议纪要...")
source_text = self._get_meeting_summaries(kb_id)
# 3. 构建提示词
self._update_task_status_in_redis(task_id, 'processing', 30, message="准备AI提示词...")
full_prompt = self._build_prompt(source_text, user_prompt)
# 4. 调用LLM API
self._update_task_status_in_redis(task_id, 'processing', 50, message="AI正在生成知识库...")
generated_content = self.llm_service._call_llm_api(full_prompt)
if not generated_content:
raise Exception("LLM API调用失败或返回空内容")
# 5. 保存结果到数据库
self._update_task_status_in_redis(task_id, 'processing', 95, message="保存结果...")
self._save_result_to_db(kb_id, generated_content)
# 6. 任务完成
self._update_task_in_db(task_id, 'completed', 100)
self._update_task_status_in_redis(task_id, 'completed', 100)
print(f"Task {task_id} completed successfully")
except Exception as e:
error_msg = str(e)
print(f"Task {task_id} failed: {error_msg}")
# 更新失败状态
self._update_task_in_db(task_id, 'failed', 0, error_message=error_msg)
self._update_task_status_in_redis(task_id, 'failed', 0, error_message=error_msg)
# --- 知识库相关方法 ---
def _get_meeting_summaries(self, kb_id: int) -> str:
"""
从数据库获取知识库关联的会议总结
Args:
kb_id: 知识库ID
Returns:
str: 拼接后的会议总结文本
"""
try:
with get_db_connection() as connection:
cursor = connection.cursor(dictionary=True)
# Get source meeting summaries
source_text = ""
# 获取知识库的源会议ID列表
cursor.execute("SELECT source_meeting_ids FROM knowledge_bases WHERE kb_id = %s", (kb_id,))
kb_info = cursor.fetchone()
if kb_info and kb_info['source_meeting_ids']:
self._update_task_status_in_redis(task_id, 'processing', 20, message="获取关联会议纪要...")
if not kb_info or not kb_info['source_meeting_ids']:
return ""
# 解析会议ID列表
meeting_ids = [int(m_id) for m_id in kb_info['source_meeting_ids'].split(',') if m_id.isdigit()]
if meeting_ids:
if not meeting_ids:
return ""
# 获取所有会议的总结
summaries = []
for meeting_id in meeting_ids:
cursor.execute("SELECT summary FROM meetings WHERE meeting_id = %s", (meeting_id,))
summary = cursor.fetchone()
if summary and summary['summary']:
summaries.append(summary['summary'])
source_text = "\n\n---\n\n".join(summaries)
# Get system prompt
self._update_task_status_in_redis(task_id, 'processing', 30, message="获取知识库生成模版...")
system_prompt = self._get_knowledge_task_prompt(cursor)
# Build final prompt
final_prompt = f"{system_prompt}\n\n"
if source_text:
final_prompt += f"请参考以下会议纪要内容:\n{source_text}\n\n"
final_prompt += f"用户要求:{user_prompt}"
self._update_task_status_in_redis(task_id, 'processing', 50, message="AI正在生成知识库...")
generated_content = self.llm_service._call_llm_api(final_prompt)
if not generated_content:
raise Exception("LLM API call failed or returned empty content")
self._update_task_status_in_redis(task_id, 'processing', 95, message="保存结果...")
self._save_result_to_db(kb_id, generated_content, cursor)
self._update_task_in_db(task_id, 'completed', 100, cursor=cursor)
self._update_task_status_in_redis(task_id, 'completed', 100)
connection.commit()
print(f"Task {task_id} completed successfully")
# 用分隔符拼接多个会议总结
return "\n\n---\n\n".join(summaries)
except Exception as e:
error_msg = str(e)
print(f"Task {task_id} failed: {error_msg}")
# Use a new connection for error logging to avoid issues with a potentially broken transaction
with get_db_connection() as err_conn:
err_cursor = err_conn.cursor()
self._update_task_in_db(task_id, 'failed', 0, error_message=error_msg, cursor=err_cursor)
err_conn.commit()
self._update_task_status_in_redis(task_id, 'failed', 0, error_message=error_msg)
print(f"获取会议总结错误: {e}")
return ""
def _get_knowledge_task_prompt(self, cursor) -> str:
query = """
SELECT p.content
FROM prompt_config pc
JOIN prompts p ON pc.prompt_id = p.id
WHERE pc.task_name = 'KNOWLEDGE_TASK'
def _build_prompt(self, source_text: str, user_prompt: str) -> str:
"""
cursor.execute(query)
result = cursor.fetchone()
if result:
return result['content']
else:
# Fallback prompt
return "Please generate a knowledge base article based on the provided information."
构建完整的提示词
使用数据库中配置的KNOWLEDGE_TASK提示词模板
Args:
source_text: 源会议总结文本
user_prompt: 用户自定义提示词
Returns:
str: 完整的提示词
"""
# 从数据库获取知识库任务的提示词模板
system_prompt = self.llm_service.get_task_prompt('KNOWLEDGE_TASK')
def _save_result_to_db(self, kb_id: int, content: str, cursor):
prompt = f"{system_prompt}\n\n"
if source_text:
prompt += f"请参考以下会议纪要内容:\n{source_text}\n\n"
prompt += f"用户要求:{user_prompt}"
return prompt
def _save_result_to_db(self, kb_id: int, content: str) -> Optional[int]:
"""
保存生成结果到数据库
Args:
kb_id: 知识库ID
content: 生成的内容
Returns:
Optional[int]: 知识库ID失败返回None
"""
try:
with get_db_connection() as connection:
cursor = connection.cursor()
query = "UPDATE knowledge_bases SET content = %s, updated_at = NOW() WHERE kb_id = %s"
cursor.execute(query, (content, kb_id))
connection.commit()
def _update_task_in_db(self, task_id: str, status: str, progress: int, error_message: str = None, cursor=None):
query = "UPDATE knowledge_base_tasks SET status = %s, progress = %s, error_message = %s, updated_at = NOW(), completed_at = IF(%s = 'completed', NOW(), completed_at) WHERE task_id = %s"
cursor.execute(query, (status, progress, error_message, status, task_id))
print(f"成功保存知识库内容kb_id: {kb_id}")
return kb_id
def _update_task_status_in_redis(self, task_id: str, status: str, progress: int, message: str = None, error_message: str = None):
update_data = {
'status': status,
'progress': str(progress),
'updated_at': datetime.now().isoformat()
}
if message: update_data['message'] = message
if error_message: update_data['error_message'] = error_message
self.redis_client.hset(f"kb_task:{task_id}", mapping=update_data)
except Exception as e:
print(f"保存知识库内容错误: {e}")
return None
# --- 状态查询和数据库操作方法 ---
def get_task_status(self, task_id: str) -> Dict[str, Any]:
"""获取任务状态 - 与 async_llm_service 保持一致"""
"""获取任务状态"""
try:
task_data = self.redis_client.hgetall(f"kb_task:{task_id}")
if not task_data:
@ -170,6 +229,20 @@ class AsyncKnowledgeBaseService:
print(f"Error getting task status: {e}")
return {'task_id': task_id, 'status': 'error', 'error_message': str(e)}
def _update_task_status_in_redis(self, task_id: str, status: str, progress: int, message: str = None, error_message: str = None):
"""更新Redis中的任务状态"""
try:
update_data = {
'status': status,
'progress': str(progress),
'updated_at': datetime.now().isoformat()
}
if message: update_data['message'] = message
if error_message: update_data['error_message'] = error_message
self.redis_client.hset(f"kb_task:{task_id}", mapping=update_data)
except Exception as e:
print(f"Error updating task status in Redis: {e}")
def _save_task_to_db(self, task_id: str, user_id: int, kb_id: int, user_prompt: str):
"""保存任务到数据库"""
try:
@ -182,6 +255,17 @@ class AsyncKnowledgeBaseService:
print(f"Error saving task to database: {e}")
raise
def _update_task_in_db(self, task_id: str, status: str, progress: int, error_message: str = None):
"""更新数据库中的任务状态"""
try:
with get_db_connection() as connection:
cursor = connection.cursor()
query = "UPDATE knowledge_base_tasks SET status = %s, progress = %s, error_message = %s, updated_at = NOW(), completed_at = IF(%s = 'completed', NOW(), completed_at) WHERE task_id = %s"
cursor.execute(query, (status, progress, error_message, status, task_id))
connection.commit()
except Exception as e:
print(f"Error updating task in database: {e}")
def _get_task_from_db(self, task_id: str) -> Optional[Dict[str, str]]:
"""从数据库获取任务信息"""
try:
@ -198,4 +282,5 @@ class AsyncKnowledgeBaseService:
print(f"Error getting task from database: {e}")
return None
# 创建全局实例
async_kb_service = AsyncKnowledgeBaseService()

View File

@ -1,5 +1,5 @@
"""
异步LLM服务 - 处理会议总结生成的异步任务
异步会议服务 - 处理会议总结生成的异步任务
采用FastAPI BackgroundTasks模式
"""
import uuid
@ -12,8 +12,8 @@ from app.core.config import REDIS_CONFIG
from app.core.database import get_db_connection
from app.services.llm_service import LLMService
class AsyncLLMService:
"""异步LLM服务类 - 采用FastAPI BackgroundTasks模式"""
class AsyncMeetingService:
"""异步会议服务类 - 处理会议相关的异步任务"""
def __init__(self):
# 确保redis客户端自动解码响应代码更简洁
@ -53,7 +53,7 @@ class AsyncLLMService:
self.redis_client.hset(f"llm_task:{task_id}", mapping=task_data)
self.redis_client.expire(f"llm_task:{task_id}", 86400)
print(f"LLM summary task created: {task_id} for meeting: {meeting_id}")
print(f"Meeting summary task created: {task_id} for meeting: {meeting_id}")
return task_id
except Exception as e:
@ -64,7 +64,7 @@ class AsyncLLMService:
"""
处理单个异步任务的函数设计为由BackgroundTasks调用
"""
print(f"Background task started for LLM task: {task_id}")
print(f"Background task started for meeting summary task: {task_id}")
try:
# 从Redis获取任务数据
task_data = self.redis_client.hgetall(f"llm_task:{task_id}")
@ -80,13 +80,13 @@ class AsyncLLMService:
# 2. 获取会议转录内容
self._update_task_status_in_redis(task_id, 'processing', 30, message="获取会议转录内容...")
transcript_text = self.llm_service._get_meeting_transcript(meeting_id)
transcript_text = self._get_meeting_transcript(meeting_id)
if not transcript_text:
raise Exception("无法获取会议转录内容")
# 3. 构建提示词
self._update_task_status_in_redis(task_id, 'processing', 40, message="准备AI提示词...")
full_prompt = self.llm_service._build_prompt(transcript_text, user_prompt)
full_prompt = self._build_prompt(transcript_text, user_prompt)
# 4. 调用LLM API
self._update_task_status_in_redis(task_id, 'processing', 50, message="AI正在分析会议内容...")
@ -96,7 +96,7 @@ class AsyncLLMService:
# 5. 保存结果到主表
self._update_task_status_in_redis(task_id, 'processing', 95, message="保存总结结果...")
self.llm_service._save_summary_to_db(meeting_id, summary_content, user_prompt)
self._save_summary_to_db(meeting_id, summary_content, user_prompt)
# 6. 任务完成
self._update_task_in_db(task_id, 'completed', 100, result=summary_content)
@ -110,6 +110,78 @@ class AsyncLLMService:
self._update_task_in_db(task_id, 'failed', 0, error_message=error_msg)
self._update_task_status_in_redis(task_id, 'failed', 0, error_message=error_msg)
# --- 会议相关方法 ---
def _get_meeting_transcript(self, meeting_id: int) -> str:
"""从数据库获取会议转录内容"""
try:
with get_db_connection() as connection:
cursor = connection.cursor()
query = """
SELECT speaker_tag, start_time_ms, end_time_ms, text_content
FROM transcript_segments
WHERE meeting_id = %s
ORDER BY start_time_ms
"""
cursor.execute(query, (meeting_id,))
segments = cursor.fetchall()
if not segments:
return ""
# 组装转录文本
transcript_lines = []
for speaker_tag, start_time, end_time, text in segments:
# 将毫秒转换为分:秒格式
start_min = start_time // 60000
start_sec = (start_time % 60000) // 1000
transcript_lines.append(f"[{start_min:02d}:{start_sec:02d}] 说话人{speaker_tag}: {text}")
return "\n".join(transcript_lines)
except Exception as e:
print(f"获取会议转录内容错误: {e}")
return ""
def _build_prompt(self, transcript_text: str, user_prompt: str) -> str:
"""
构建完整的提示词
使用数据库中配置的MEETING_TASK提示词模板
"""
# 从数据库获取会议任务的提示词模板
system_prompt = self.llm_service.get_task_prompt('MEETING_TASK')
prompt = f"{system_prompt}\n\n"
if user_prompt:
prompt += f"用户额外要求:{user_prompt}\n\n"
prompt += f"会议转录内容:\n{transcript_text}\n\n请根据以上内容生成会议总结:"
return prompt
def _save_summary_to_db(self, meeting_id: int, summary_content: str, user_prompt: str) -> Optional[int]:
"""保存总结到数据库 - 更新meetings表的summary、user_prompt和updated_at字段"""
try:
with get_db_connection() as connection:
cursor = connection.cursor()
# 更新meetings表的summary、user_prompt和updated_at字段
update_query = """
UPDATE meetings
SET summary = %s, user_prompt = %s, updated_at = NOW()
WHERE meeting_id = %s
"""
cursor.execute(update_query, (summary_content, user_prompt, meeting_id))
connection.commit()
print(f"成功保存会议总结到meetings表meeting_id: {meeting_id}")
return meeting_id
except Exception as e:
print(f"保存总结到数据库错误: {e}")
return None
# --- 状态查询和数据库操作方法 ---
def get_task_status(self, task_id: str) -> Dict[str, Any]:
@ -212,4 +284,4 @@ class AsyncLLMService:
return None
# 创建全局实例
async_llm_service = AsyncLLMService()
async_meeting_service = AsyncMeetingService()

View File

@ -7,6 +7,8 @@ from app.core.database import get_db_connection
class LLMService:
"""LLM服务 - 专注于大模型API调用和提示词管理"""
def __init__(self):
# 设置dashscope API key
dashscope.api_key = config_module.QWEN_API_KEY
@ -36,123 +38,47 @@ class LLMService:
"""动态获取top_p"""
return config_module.LLM_CONFIG["top_p"]
def generate_meeting_summary_stream(self, meeting_id: int, user_prompt: str = "") -> Generator[str, None, None]:
def get_task_prompt(self, task_name: str, cursor=None) -> str:
"""
流式生成会议总结
统一的提示词获取方法
Args:
meeting_id: 会议ID
user_prompt: 用户额外提示词
Yields:
str: 流式输出的内容片段
"""
try:
# 获取会议转录内容
transcript_text = self._get_meeting_transcript(meeting_id)
if not transcript_text:
yield "error: 无法获取会议转录内容"
return
# 构建完整提示词
full_prompt = self._build_prompt(transcript_text, user_prompt)
# 调用大模型API进行流式生成
full_content = ""
for chunk in self._call_llm_api_stream(full_prompt):
if chunk.startswith("error:"):
yield chunk
return
full_content += chunk
yield chunk
# 保存完整总结到数据库
if full_content:
self._save_summary_to_db(meeting_id, full_content, user_prompt)
except Exception as e:
print(f"流式生成会议总结错误: {e}")
yield f"error: {str(e)}"
def generate_meeting_summary(self, meeting_id: int, user_prompt: str = "") -> Optional[Dict]:
"""
生成会议总结非流式保持向后兼容
Args:
meeting_id: 会议ID
user_prompt: 用户额外提示词
task_name: 任务名称 'MEETING_TASK', 'KNOWLEDGE_TASK'
cursor: 数据库游标如果传入则使用否则创建新连接
Returns:
包含总结内容的字典如果失败返回None
str: 提示词内容如果未找到返回默认提示词
"""
try:
# 获取会议转录内容
transcript_text = self._get_meeting_transcript(meeting_id)
if not transcript_text:
return {"error": "无法获取会议转录内容"}
# 构建完整提示词
full_prompt = self._build_prompt(transcript_text, user_prompt)
# 调用大模型API
response = self._call_llm_api(full_prompt)
if response:
# 保存总结到数据库
summary_id = self._save_summary_to_db(meeting_id, response, user_prompt)
return {
"summary_id": summary_id,
"content": response,
"meeting_id": meeting_id
}
else:
return {"error": "大模型API调用失败"}
except Exception as e:
print(f"生成会议总结错误: {e}")
return {"error": str(e)}
def _get_meeting_transcript(self, meeting_id: int) -> str:
"""从数据库获取会议转录内容"""
try:
with get_db_connection() as connection:
cursor = connection.cursor()
query = """
SELECT speaker_tag, start_time_ms, end_time_ms, text_content
FROM transcript_segments
WHERE meeting_id = %s
ORDER BY start_time_ms
SELECT p.content
FROM prompt_config pc
JOIN prompts p ON pc.prompt_id = p.id
WHERE pc.task_name = %s
"""
cursor.execute(query, (meeting_id,))
segments = cursor.fetchall()
if not segments:
return ""
if cursor:
cursor.execute(query, (task_name,))
result = cursor.fetchone()
if result:
return result['content'] if isinstance(result, dict) else result[0]
else:
with get_db_connection() as connection:
cursor = connection.cursor(dictionary=True)
cursor.execute(query, (task_name,))
result = cursor.fetchone()
if result:
return result['content']
# 组装转录文本
transcript_lines = []
for speaker_tag, start_time, end_time, text in segments:
# 将毫秒转换为分:秒格式
start_min = start_time // 60000
start_sec = (start_time % 60000) // 1000
transcript_lines.append(f"[{start_min:02d}:{start_sec:02d}] 说话人{speaker_tag}: {text}")
# 返回默认提示词
return self._get_default_prompt(task_name)
return "\n".join(transcript_lines)
except Exception as e:
print(f"获取会议转录内容错误: {e}")
return ""
def _build_prompt(self, transcript_text: str, user_prompt: str) -> str:
"""构建完整的提示词"""
prompt = f"{self.system_prompt}\n\n"
if user_prompt:
prompt += f"用户额外要求:{user_prompt}\n\n"
prompt += f"会议转录内容:\n{transcript_text}\n\n请根据以上内容生成会议总结:"
return prompt
def _get_default_prompt(self, task_name: str) -> str:
"""获取默认提示词"""
default_prompts = {
'MEETING_TASK': self.system_prompt, # 使用配置文件中的系统提示词
'KNOWLEDGE_TASK': "请根据提供的信息生成知识库文章。",
}
return default_prompts.get(task_name, "请根据提供的内容进行总结和分析。")
def _call_llm_api_stream(self, prompt: str) -> Generator[str, None, None]:
"""流式调用阿里Qwen3大模型API"""
@ -185,7 +111,7 @@ class LLMService:
yield f"error: {error_msg}"
def _call_llm_api(self, prompt: str) -> Optional[str]:
"""调用阿里Qwen3大模型API非流式,保持向后兼容"""
"""调用阿里Qwen3大模型API非流式"""
try:
response = dashscope.Generation.call(
model=self.model_name,
@ -205,95 +131,17 @@ class LLMService:
print(f"调用大模型API错误: {e}")
return None
def _save_summary_to_db(self, meeting_id: int, summary_content: str, user_prompt: str) -> Optional[int]:
"""保存总结到数据库 - 更新meetings表的summary字段"""
try:
with get_db_connection() as connection:
cursor = connection.cursor()
# 更新meetings表的summary字段
update_query = """
UPDATE meetings
SET summary = %s
WHERE meeting_id = %s
"""
cursor.execute(update_query, (summary_content, meeting_id))
connection.commit()
print(f"成功保存会议总结到meetings表meeting_id: {meeting_id}")
return meeting_id
except Exception as e:
print(f"保存总结到数据库错误: {e}")
return None
def get_meeting_summaries(self, meeting_id: int) -> List[Dict]:
"""获取会议的当前总结 - 从meetings表的summary字段获取"""
try:
with get_db_connection() as connection:
cursor = connection.cursor()
query = """
SELECT summary
FROM meetings
WHERE meeting_id = %s
"""
cursor.execute(query, (meeting_id,))
result = cursor.fetchone()
# 如果有总结内容返回一个包含当前总结的列表格式保持API一致性
if result and result[0]:
return [{
"id": meeting_id,
"content": result[0],
"user_prompt": "", # meetings表中没有user_prompt字段
"created_at": None # meetings表中没有单独的总结创建时间
}]
else:
return []
except Exception as e:
print(f"获取会议总结错误: {e}")
return []
def get_current_meeting_summary(self, meeting_id: int) -> Optional[str]:
"""获取会议当前的总结内容 - 从meetings表的summary字段获取"""
try:
with get_db_connection() as connection:
cursor = connection.cursor()
query = """
SELECT summary
FROM meetings
WHERE meeting_id = %s
"""
cursor.execute(query, (meeting_id,))
result = cursor.fetchone()
return result[0] if result and result[0] else None
except Exception as e:
print(f"获取会议当前总结错误: {e}")
return None
# 测试代码
if __name__ == '__main__':
# 测试LLM服务
test_meeting_id = 38
test_user_prompt = "请重点关注决策事项和待办任务"
print("--- 运行LLM服务测试 ---")
llm_service = LLMService()
# 生成总结
result = llm_service.generate_meeting_summary(test_meeting_id, test_user_prompt)
if result.get("error"):
print(f"生成总结失败: {result['error']}")
else:
print(f"总结生成成功ID: {result.get('summary_id')}")
print(f"总结内容: {result.get('content')[:200]}...")
# 测试获取任务提示词
meeting_prompt = llm_service.get_task_prompt('MEETING_TASK')
print(f"会议任务提示词: {meeting_prompt[:100]}...")
# 获取历史总结
summaries = llm_service.get_meeting_summaries(test_meeting_id)
print(f"获取到 {len(summaries)} 个历史总结")
knowledge_prompt = llm_service.get_task_prompt('KNOWLEDGE_TASK')
print(f"知识库任务提示词: {knowledge_prompt[:100]}...")
print("--- LLM服务测试完成 ---")

View File

@ -0,0 +1,218 @@
"""
声纹服务 - 处理用户声纹采集存储和验证
"""
import os
import json
import wave
from datetime import datetime
from typing import Optional, Dict
from pathlib import Path
from app.core.database import get_db_connection
import app.core.config as config_module
class VoiceprintService:
"""声纹服务类 - 同步处理声纹采集"""
def __init__(self):
self.voiceprint_dir = config_module.VOICEPRINT_DIR
self.voiceprint_config = config_module.VOICEPRINT_CONFIG
def get_user_voiceprint_status(self, user_id: int) -> Dict:
"""
获取用户声纹状态
Args:
user_id: 用户ID
Returns:
Dict: 声纹状态信息
"""
try:
with get_db_connection() as connection:
cursor = connection.cursor(dictionary=True)
query = """
SELECT vp_id, user_id, file_path, file_size, duration_seconds, collected_at, updated_at
FROM user_voiceprint
WHERE user_id = %s
"""
cursor.execute(query, (user_id,))
voiceprint = cursor.fetchone()
if voiceprint:
return {
"has_voiceprint": True,
"vp_id": voiceprint['vp_id'],
"file_path": voiceprint['file_path'],
"duration_seconds": float(voiceprint['duration_seconds']) if voiceprint['duration_seconds'] else None,
"collected_at": voiceprint['collected_at'].isoformat() if voiceprint['collected_at'] else None
}
else:
return {
"has_voiceprint": False,
"vp_id": None,
"file_path": None,
"duration_seconds": None,
"collected_at": None
}
except Exception as e:
print(f"获取声纹状态错误: {e}")
raise e
def save_voiceprint(self, user_id: int, audio_file_path: str, file_size: int) -> Dict:
"""
保存声纹文件并提取特征向量
Args:
user_id: 用户ID
audio_file_path: 音频文件路径
file_size: 文件大小
Returns:
Dict: 保存结果
"""
try:
# 1. 获取音频时长
duration = self._get_audio_duration(audio_file_path)
# 2. 提取声纹向量调用FunASR
vector_data = self._extract_voiceprint_vector(audio_file_path)
# 3. 保存到数据库
with get_db_connection() as connection:
cursor = connection.cursor(dictionary=True)
# 检查用户是否已有声纹
cursor.execute("SELECT vp_id FROM user_voiceprint WHERE user_id = %s", (user_id,))
existing = cursor.fetchone()
# 计算相对路径
relative_path = str(Path(audio_file_path).relative_to(config_module.BASE_DIR))
if existing:
# 更新现有记录
update_query = """
UPDATE user_voiceprint
SET file_path = %s, file_size = %s, duration_seconds = %s,
vector_data = %s, updated_at = NOW()
WHERE user_id = %s
"""
cursor.execute(update_query, (
relative_path, file_size, duration,
json.dumps(vector_data) if vector_data else None,
user_id
))
vp_id = existing['vp_id']
else:
# 插入新记录
insert_query = """
INSERT INTO user_voiceprint
(user_id, file_path, file_size, duration_seconds, vector_data, collected_at, updated_at)
VALUES (%s, %s, %s, %s, %s, NOW(), NOW())
"""
cursor.execute(insert_query, (
user_id, relative_path, file_size, duration,
json.dumps(vector_data) if vector_data else None
))
vp_id = cursor.lastrowid
connection.commit()
return {
"vp_id": vp_id,
"user_id": user_id,
"file_path": relative_path,
"file_size": file_size,
"duration_seconds": duration,
"has_vector": vector_data is not None
}
except Exception as e:
print(f"保存声纹错误: {e}")
raise e
def delete_voiceprint(self, user_id: int) -> bool:
"""
删除用户声纹
Args:
user_id: 用户ID
Returns:
bool: 是否删除成功
"""
try:
with get_db_connection() as connection:
cursor = connection.cursor(dictionary=True)
# 获取文件路径
cursor.execute("SELECT file_path FROM user_voiceprint WHERE user_id = %s", (user_id,))
voiceprint = cursor.fetchone()
if voiceprint:
# 构建完整文件路径
relative_path = voiceprint['file_path']
if relative_path.startswith('/'):
relative_path = relative_path.lstrip('/')
file_path = config_module.BASE_DIR / relative_path
# 删除数据库记录
cursor.execute("DELETE FROM user_voiceprint WHERE user_id = %s", (user_id,))
connection.commit()
# 删除文件
if file_path.exists():
os.remove(file_path)
return True
else:
return False
except Exception as e:
print(f"删除声纹错误: {e}")
raise e
def _get_audio_duration(self, audio_file_path: str) -> float:
"""
获取音频文件时长
Args:
audio_file_path: 音频文件路径
Returns:
float: 时长
"""
try:
with wave.open(audio_file_path, 'rb') as wav_file:
frames = wav_file.getnframes()
rate = wav_file.getframerate()
duration = frames / float(rate)
return round(duration, 2)
except Exception as e:
print(f"获取音频时长错误: {e}")
return 10.0 # 默认返回10秒
def _extract_voiceprint_vector(self, audio_file_path: str) -> Optional[list]:
"""
提取声纹特征向量调用FunASR
Args:
audio_file_path: 音频文件路径
Returns:
Optional[list]: 声纹向量192失败返回None
"""
# TODO: 集成FunASR的说话人识别模型
# 使用 speech_campplus_sv_zh-cn_16k-common 模型
# 返回192维的embedding向量
print(f"[TODO] 调用FunASR提取声纹向量: {audio_file_path}")
# 暂时返回None等待FunASR集成
# 集成后应该返回类似: [0.123, -0.456, 0.789, ...]
return None
# 创建全局实例
voiceprint_service = VoiceprintService()

View File

@ -1,7 +1,7 @@
{
"model_name": "qwen-plus",
"system_prompt": "你是一个专业的会议记录分析助手。请根据提供的会议转录内容,生成简洁明了的会议总结。\n\n总结包括五个部分名称严格一致生成为MD二级目录\n1. 会议概述 - 简要说明会议的主要目的和背景(生成MD引用)\n2. 主要讨论点 - 列出会议中讨论的重要话题和内容\n3. 决策事项 - 明确记录会议中做出的决定和结论\n4. 待办事项 - 列出需要后续跟进的任务和责任人\n5. 关键信息 - 其他重要的信息点\n\n输出要求\n- 保持客观中性,不添加个人观点\n- 使用简洁、准确的中文表达\n- 按重要性排序各项内容\n- 如果某个部分没有相关内容,可以说明\"无相关内容\"\n- 总字数控制在500字以内",
"DEFAULT_RESET_PASSWORD": "111111",
"MAX_FILE_SIZE": 209715200,
"DEFAULT_RESET_PASSWORD": "123456",
"MAX_FILE_SIZE": 208666624,
"MAX_IMAGE_SIZE": 10485760
}

View File

@ -2,7 +2,7 @@ import uvicorn
from fastapi import FastAPI, Request, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from app.api.endpoints import auth, users, meetings, tags, admin, tasks, prompts, knowledge_base, client_downloads
from app.api.endpoints import auth, users, meetings, tags, admin, tasks, prompts, knowledge_base, client_downloads, voiceprint
from app.core.config import UPLOAD_DIR, API_CONFIG
from app.api.endpoints.admin import load_system_config
import os
@ -39,6 +39,7 @@ app.include_router(tasks.router, prefix="/api", tags=["Tasks"])
app.include_router(prompts.router, prefix="/api", tags=["Prompts"])
app.include_router(knowledge_base.router, prefix="/api", tags=["KnowledgeBase"])
app.include_router(client_downloads.router, prefix="/api/clients", tags=["ClientDownloads"])
app.include_router(voiceprint.router, prefix="/api", tags=["Voiceprint"])
@app.get("/")
def read_root():

View File

@ -6,10 +6,10 @@ import sys
import os
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from app.services.async_llm_service import AsyncLLMService
from app.services.async_meeting_service import AsyncMeetingService
# 创建服务实例
service = AsyncLLMService()
service = AsyncMeetingService()
# 创建测试任务
meeting_id = 38

View File

@ -8,10 +8,10 @@ import time
import threading
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from app.services.async_llm_service import AsyncLLMService
from app.services.async_meeting_service import AsyncMeetingService
# 创建服务实例
service = AsyncLLMService()
service = AsyncMeetingService()
# 直接调用处理任务方法测试
print("测试直接调用_process_tasks方法...")

View File

@ -0,0 +1,141 @@
"""
声纹采集API测试脚本
使用方法
1. 确保后端服务正在运行
2. 修改 USER_ID TOKEN 为实际值
3. 准备一个10秒的WAV音频文件
4. 运行: python test_voiceprint_api.py
"""
import requests
import json
# 配置
BASE_URL = "http://localhost:8000/api"
USER_ID = 1 # 修改为实际用户ID
TOKEN = "" # 登录后获取的token
# 请求头
headers = {
"Authorization": f"Bearer {TOKEN}",
"Content-Type": "application/json"
}
def test_get_template():
"""测试获取朗读模板"""
print("\n=== 测试1: 获取朗读模板 ===")
url = f"{BASE_URL}/voiceprint/template"
response = requests.get(url, headers=headers)
print(f"状态码: {response.status_code}")
print(f"响应: {json.dumps(response.json(), ensure_ascii=False, indent=2)}")
return response.json()
def test_get_status(user_id):
"""测试获取声纹状态"""
print(f"\n=== 测试2: 获取用户 {user_id} 的声纹状态 ===")
url = f"{BASE_URL}/voiceprint/{user_id}"
response = requests.get(url, headers=headers)
print(f"状态码: {response.status_code}")
print(f"响应: {json.dumps(response.json(), ensure_ascii=False, indent=2)}")
return response.json()
def test_upload_voiceprint(user_id, audio_file_path):
"""测试上传声纹"""
print(f"\n=== 测试3: 上传声纹音频 ===")
url = f"{BASE_URL}/voiceprint/{user_id}"
# 移除Content-Type让requests自动设置multipart/form-data
upload_headers = {
"Authorization": f"Bearer {TOKEN}"
}
with open(audio_file_path, 'rb') as f:
files = {'audio_file': (audio_file_path.split('/')[-1], f, 'audio/wav')}
response = requests.post(url, headers=upload_headers, files=files)
print(f"状态码: {response.status_code}")
print(f"响应: {json.dumps(response.json(), ensure_ascii=False, indent=2)}")
return response.json()
def test_delete_voiceprint(user_id):
"""测试删除声纹"""
print(f"\n=== 测试4: 删除用户 {user_id} 的声纹 ===")
url = f"{BASE_URL}/voiceprint/{user_id}"
response = requests.delete(url, headers=headers)
print(f"状态码: {response.status_code}")
print(f"响应: {json.dumps(response.json(), ensure_ascii=False, indent=2)}")
return response.json()
def login(username, password):
"""登录获取token"""
print("\n=== 登录获取Token ===")
url = f"{BASE_URL}/auth/login"
data = {
"username": username,
"password": password
}
response = requests.post(url, json=data)
if response.status_code == 200:
result = response.json()
if result.get('code') == '200':
token = result['data']['token']
print(f"登录成功Token: {token[:20]}...")
return token
else:
print(f"登录失败: {result.get('message')}")
return None
else:
print(f"请求失败,状态码: {response.status_code}")
return None
if __name__ == "__main__":
print("=" * 60)
print("声纹采集API测试脚本")
print("=" * 60)
# 步骤1: 登录如果没有token
if not TOKEN:
print("\n请先登录获取Token...")
username = input("用户名: ")
password = input("密码: ")
TOKEN = login(username, password)
if TOKEN:
headers["Authorization"] = f"Bearer {TOKEN}"
else:
print("登录失败,退出测试")
exit(1)
# 步骤2: 测试获取朗读模板
test_get_template()
# 步骤3: 测试获取声纹状态
test_get_status(USER_ID)
# 步骤4: 测试上传声纹(需要准备音频文件)
audio_file = input("\n请输入WAV音频文件路径 (回车跳过上传测试): ")
if audio_file.strip():
test_upload_voiceprint(USER_ID, audio_file.strip())
# 上传后再次查看状态
print("\n=== 上传后再次查看状态 ===")
test_get_status(USER_ID)
# 步骤5: 测试删除声纹
confirm = input("\n是否测试删除声纹? (yes/no): ")
if confirm.lower() == 'yes':
test_delete_voiceprint(USER_ID)
# 删除后再次查看状态
print("\n=== 删除后再次查看状态 ===")
test_get_status(USER_ID)
print("\n" + "=" * 60)
print("测试完成")
print("=" * 60)