1.0.3
parent
976ea854b6
commit
2f36474f4d
|
|
@ -9,29 +9,25 @@ import datetime
|
|||
|
||||
router = APIRouter()
|
||||
|
||||
def _process_tags(cursor, tag_string: Optional[str]) -> List[Tag]:
|
||||
def _process_tags(cursor, tag_string: Optional[str], creator_id: Optional[int] = None) -> List[Tag]:
|
||||
"""
|
||||
处理标签:查询已存在的标签,如果提供了 creator_id 则创建不存在的标签
|
||||
"""
|
||||
if not tag_string:
|
||||
return []
|
||||
tag_names = [name.strip() for name in tag_string.split(',') if name.strip()]
|
||||
if not tag_names:
|
||||
return []
|
||||
|
||||
placeholders = ','.join(['%s'] * len(tag_names))
|
||||
select_query = f"SELECT id, name, color FROM tags WHERE name IN ({placeholders})"
|
||||
cursor.execute(select_query, tuple(tag_names))
|
||||
|
||||
# 如果提供了 creator_id,则创建不存在的标签
|
||||
if creator_id:
|
||||
insert_ignore_query = "INSERT IGNORE INTO tags (name, creator_id) VALUES (%s, %s)"
|
||||
cursor.executemany(insert_ignore_query, [(name, creator_id) for name in tag_names])
|
||||
|
||||
# 查询所有标签信息
|
||||
format_strings = ', '.join(['%s'] * len(tag_names))
|
||||
cursor.execute(f"SELECT id, name, color FROM tags WHERE name IN ({format_strings})", tuple(tag_names))
|
||||
tags_data = cursor.fetchall()
|
||||
|
||||
existing_tags = {tag['name']: tag for tag in tags_data}
|
||||
new_tags = [name for name in tag_names if name not in existing_tags]
|
||||
|
||||
if new_tags:
|
||||
insert_query = "INSERT INTO tags (name) VALUES (%s)"
|
||||
cursor.executemany(insert_query, [(name,) for name in new_tags])
|
||||
|
||||
# Re-fetch all tags to get their IDs and default colors
|
||||
cursor.execute(select_query, tuple(tag_names))
|
||||
tags_data = cursor.fetchall()
|
||||
|
||||
return [Tag(**tag) for tag in tags_data]
|
||||
|
||||
|
||||
|
|
@ -83,7 +79,8 @@ def get_knowledge_bases(
|
|||
|
||||
kb_list = []
|
||||
for kb_data in kbs_data:
|
||||
kb_data['tags'] = _process_tags(cursor, kb_data.get('tags'))
|
||||
# 列表页不需要处理 tags,直接使用字符串
|
||||
# kb_data['tags'] 保持原样(逗号分隔的标签名称字符串)
|
||||
# Count source meetings - filter empty strings
|
||||
if kb_data.get('source_meeting_ids'):
|
||||
meeting_ids = [mid.strip() for mid in kb_data['source_meeting_ids'].split(',') if mid.strip()]
|
||||
|
|
@ -122,7 +119,7 @@ def create_knowledge_base(
|
|||
request.is_shared,
|
||||
request.source_meeting_ids,
|
||||
request.user_prompt,
|
||||
request.tags,
|
||||
request.tags, # 创建时 tags 应该为 None 或空字符串
|
||||
now,
|
||||
now
|
||||
))
|
||||
|
|
@ -136,7 +133,7 @@ def create_knowledge_base(
|
|||
source_meeting_ids=request.source_meeting_ids,
|
||||
cursor=cursor
|
||||
)
|
||||
|
||||
|
||||
connection.commit()
|
||||
|
||||
# Add the background task to process the knowledge base generation
|
||||
|
|
@ -171,7 +168,8 @@ def get_knowledge_base_detail(
|
|||
if not kb_data['is_shared'] and kb_data['creator_id'] != current_user['user_id']:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
# Process tags
|
||||
# Process tags - 获取标签的完整信息(包括颜色)
|
||||
# 详情页不需要创建新标签,所以不传 creator_id
|
||||
kb_data['tags'] = _process_tags(cursor, kb_data.get('tags'))
|
||||
|
||||
# Get source meetings details
|
||||
|
|
@ -220,6 +218,10 @@ def update_knowledge_base(
|
|||
if kb['creator_id'] != current_user['user_id']:
|
||||
raise HTTPException(status_code=403, detail="Only the creator can update this knowledge base")
|
||||
|
||||
# 使用 _process_tags 处理标签(会自动创建新标签)
|
||||
if request.tags:
|
||||
_process_tags(cursor, request.tags, current_user['user_id'])
|
||||
|
||||
# Update the knowledge base
|
||||
now = datetime.datetime.utcnow()
|
||||
update_query = """
|
||||
|
|
|
|||
|
|
@ -23,14 +23,22 @@ transcription_service = AsyncTranscriptionService()
|
|||
class GenerateSummaryRequest(BaseModel):
|
||||
user_prompt: Optional[str] = ""
|
||||
|
||||
def _process_tags(cursor, tag_string: Optional[str]) -> List[Tag]:
|
||||
def _process_tags(cursor, tag_string: Optional[str], creator_id: Optional[int] = None) -> List[Tag]:
|
||||
"""
|
||||
处理标签:查询已存在的标签,如果提供了 creator_id 则创建不存在的标签
|
||||
"""
|
||||
if not tag_string:
|
||||
return []
|
||||
tag_names = [name.strip() for name in tag_string.split(',') if name.strip()]
|
||||
if not tag_names:
|
||||
return []
|
||||
insert_ignore_query = "INSERT IGNORE INTO tags (name) VALUES (%s)"
|
||||
cursor.executemany(insert_ignore_query, [(name,) for name in tag_names])
|
||||
|
||||
# 如果提供了 creator_id,则创建不存在的标签
|
||||
if creator_id:
|
||||
insert_ignore_query = "INSERT IGNORE INTO tags (name, creator_id) VALUES (%s, %s)"
|
||||
cursor.executemany(insert_ignore_query, [(name, creator_id) for name in tag_names])
|
||||
|
||||
# 查询所有标签信息
|
||||
format_strings = ', '.join(['%s'] * len(tag_names))
|
||||
cursor.execute(f"SELECT id, name, color FROM tags WHERE name IN ({format_strings})", tuple(tag_names))
|
||||
tags_data = cursor.fetchall()
|
||||
|
|
@ -125,10 +133,9 @@ def get_meeting_transcript(meeting_id: int, current_user: dict = Depends(get_cur
|
|||
def create_meeting(meeting_request: CreateMeetingRequest, current_user: dict = Depends(get_current_user)):
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
# 使用 _process_tags 来处理标签创建
|
||||
if meeting_request.tags:
|
||||
tag_names = [name.strip() for name in meeting_request.tags.split(',') if name.strip()]
|
||||
if tag_names:
|
||||
cursor.executemany("INSERT IGNORE INTO tags (name) VALUES (%s)", [(name,) for name in tag_names])
|
||||
_process_tags(cursor, meeting_request.tags, current_user['user_id'])
|
||||
meeting_query = 'INSERT INTO meetings (user_id, title, meeting_time, summary, tags, created_at) VALUES (%s, %s, %s, %s, %s, %s)'
|
||||
cursor.execute(meeting_query, (meeting_request.user_id, meeting_request.title, meeting_request.meeting_time, None, meeting_request.tags, datetime.now().isoformat()))
|
||||
meeting_id = cursor.lastrowid
|
||||
|
|
@ -147,10 +154,9 @@ def update_meeting(meeting_id: int, meeting_request: UpdateMeetingRequest, curre
|
|||
return create_api_response(code="404", message="Meeting not found")
|
||||
if meeting['user_id'] != current_user['user_id']:
|
||||
return create_api_response(code="403", message="Permission denied")
|
||||
# 使用 _process_tags 来处理标签创建
|
||||
if meeting_request.tags:
|
||||
tag_names = [name.strip() for name in meeting_request.tags.split(',') if name.strip()]
|
||||
if tag_names:
|
||||
cursor.executemany("INSERT IGNORE INTO tags (name) VALUES (%s)", [(name,) for name in tag_names])
|
||||
_process_tags(cursor, meeting_request.tags, current_user['user_id'])
|
||||
update_query = 'UPDATE meetings SET title = %s, meeting_time = %s, summary = %s, tags = %s WHERE meeting_id = %s'
|
||||
cursor.execute(update_query, (meeting_request.title, meeting_request.meeting_time, meeting_request.summary, meeting_request.tags, meeting_id))
|
||||
cursor.execute("DELETE FROM attendees WHERE meeting_id = %s", (meeting_id,))
|
||||
|
|
|
|||
|
|
@ -0,0 +1,131 @@
|
|||
"""
|
||||
声纹采集API接口
|
||||
"""
|
||||
from fastapi import APIRouter, Depends, UploadFile, File, HTTPException
|
||||
from typing import Optional
|
||||
from pathlib import Path
|
||||
import datetime
|
||||
|
||||
from app.models.models import VoiceprintStatus, VoiceprintTemplate
|
||||
from app.core.auth import get_current_user
|
||||
from app.core.response import create_api_response
|
||||
from app.services.voiceprint_service import voiceprint_service
|
||||
import app.core.config as config_module
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/voiceprint/template", response_model=None)
|
||||
def get_voiceprint_template(current_user: dict = Depends(get_current_user)):
|
||||
"""
|
||||
获取声纹采集朗读模板配置
|
||||
|
||||
权限:需要登录
|
||||
"""
|
||||
try:
|
||||
template_data = VoiceprintTemplate(**config_module.VOICEPRINT_CONFIG)
|
||||
return create_api_response(code="200", message="获取朗读模板成功", data=template_data.dict())
|
||||
except Exception as e:
|
||||
return create_api_response(code="500", message=f"获取朗读模板失败: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/voiceprint/{user_id}", response_model=None)
|
||||
def get_voiceprint_status(user_id: int, current_user: dict = Depends(get_current_user)):
|
||||
"""
|
||||
获取用户声纹采集状态
|
||||
|
||||
权限:用户只能查询自己的声纹状态,管理员可查询所有
|
||||
"""
|
||||
# 权限检查:只能查询自己的声纹,或者是管理员
|
||||
if current_user['user_id'] != user_id and current_user['role_id'] != 1:
|
||||
return create_api_response(code="403", message="无权限查询其他用户的声纹状态")
|
||||
|
||||
try:
|
||||
status_data = voiceprint_service.get_user_voiceprint_status(user_id)
|
||||
return create_api_response(code="200", message="获取声纹状态成功", data=status_data)
|
||||
except Exception as e:
|
||||
return create_api_response(code="500", message=f"获取声纹状态失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/voiceprint/{user_id}", response_model=None)
|
||||
async def upload_voiceprint(
|
||||
user_id: int,
|
||||
audio_file: UploadFile = File(...),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
上传声纹音频文件(同步处理)
|
||||
|
||||
权限:用户只能上传自己的声纹,管理员可操作所有
|
||||
"""
|
||||
# 权限检查
|
||||
if current_user['user_id'] != user_id and current_user['role_id'] != 1:
|
||||
return create_api_response(code="403", message="无权限上传其他用户的声纹")
|
||||
|
||||
# 检查文件格式
|
||||
file_ext = Path(audio_file.filename).suffix.lower()
|
||||
if file_ext not in config_module.ALLOWED_VOICEPRINT_EXTENSIONS:
|
||||
return create_api_response(
|
||||
code="400",
|
||||
message=f"不支持的文件格式,仅支持: {', '.join(config_module.ALLOWED_VOICEPRINT_EXTENSIONS)}"
|
||||
)
|
||||
|
||||
# 检查文件大小
|
||||
max_size = config_module.VOICEPRINT_CONFIG.get('max_file_size', 5242880) # 默认5MB
|
||||
content = await audio_file.read()
|
||||
file_size = len(content)
|
||||
|
||||
if file_size > max_size:
|
||||
return create_api_response(
|
||||
code="400",
|
||||
message=f"文件过大,最大允许 {max_size / 1024 / 1024:.1f}MB"
|
||||
)
|
||||
|
||||
try:
|
||||
# 确保用户目录存在
|
||||
user_voiceprint_dir = config_module.VOICEPRINT_DIR / str(user_id)
|
||||
user_voiceprint_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 生成文件名:时间戳.wav
|
||||
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
filename = f"{timestamp}.wav"
|
||||
file_path = user_voiceprint_dir / filename
|
||||
|
||||
# 保存文件
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(content)
|
||||
|
||||
# 调用服务处理声纹(提取特征向量,保存到数据库)
|
||||
result = voiceprint_service.save_voiceprint(user_id, str(file_path), file_size)
|
||||
|
||||
return create_api_response(code="200", message="声纹采集成功", data=result)
|
||||
|
||||
except Exception as e:
|
||||
# 如果出错,删除已上传的文件
|
||||
if 'file_path' in locals() and Path(file_path).exists():
|
||||
Path(file_path).unlink()
|
||||
|
||||
return create_api_response(code="500", message=f"声纹采集失败: {str(e)}")
|
||||
|
||||
|
||||
@router.delete("/voiceprint/{user_id}", response_model=None)
|
||||
def delete_voiceprint(user_id: int, current_user: dict = Depends(get_current_user)):
|
||||
"""
|
||||
删除用户声纹数据,允许重新采集
|
||||
|
||||
权限:用户只能删除自己的声纹,管理员可操作所有
|
||||
"""
|
||||
# 权限检查
|
||||
if current_user['user_id'] != user_id and current_user['role_id'] != 1:
|
||||
return create_api_response(code="403", message="无权限删除其他用户的声纹")
|
||||
|
||||
try:
|
||||
success = voiceprint_service.delete_voiceprint(user_id)
|
||||
|
||||
if success:
|
||||
return create_api_response(code="200", message="声纹删除成功")
|
||||
else:
|
||||
return create_api_response(code="404", message="未找到该用户的声纹数据")
|
||||
|
||||
except Exception as e:
|
||||
return create_api_response(code="500", message=f"删除声纹失败: {str(e)}")
|
||||
|
|
@ -1,4 +1,5 @@
|
|||
import os
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
# 基础路径配置
|
||||
|
|
@ -6,10 +7,12 @@ BASE_DIR = Path(__file__).parent.parent.parent
|
|||
UPLOAD_DIR = BASE_DIR / "uploads"
|
||||
AUDIO_DIR = UPLOAD_DIR / "audio"
|
||||
MARKDOWN_DIR = UPLOAD_DIR / "markdown"
|
||||
VOICEPRINT_DIR = UPLOAD_DIR / "voiceprint"
|
||||
|
||||
# 文件上传配置
|
||||
ALLOWED_EXTENSIONS = {".mp3", ".wav", ".m4a", ".mpeg", ".mp4"}
|
||||
ALLOWED_IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".gif", ".webp"}
|
||||
ALLOWED_VOICEPRINT_EXTENSIONS = {".wav"}
|
||||
MAX_FILE_SIZE = 100 * 1024 * 1024 # 100MB
|
||||
MAX_IMAGE_SIZE = 10 * 1024 * 1024 # 10MB
|
||||
|
||||
|
|
@ -17,6 +20,7 @@ MAX_IMAGE_SIZE = 10 * 1024 * 1024 # 10MB
|
|||
UPLOAD_DIR.mkdir(exist_ok=True)
|
||||
AUDIO_DIR.mkdir(exist_ok=True)
|
||||
MARKDOWN_DIR.mkdir(exist_ok=True)
|
||||
VOICEPRINT_DIR.mkdir(exist_ok=True)
|
||||
|
||||
# 数据库配置
|
||||
DATABASE_CONFIG = {
|
||||
|
|
@ -82,3 +86,12 @@ LLM_CONFIG = {
|
|||
|
||||
# 密码重置配置
|
||||
DEFAULT_RESET_PASSWORD = os.getenv('DEFAULT_RESET_PASSWORD', '111111')
|
||||
|
||||
# 加载系统配置文件
|
||||
# 默认声纹配置
|
||||
VOICEPRINT_CONFIG = {
|
||||
"template_text": "我正在进行声纹采集,这段语音将用于身份识别和验证。\n\n声纹技术能够准确识别每个人独特的声音特征。",
|
||||
"duration_seconds": 12,
|
||||
"sample_rate": 16000,
|
||||
"channels": 1
|
||||
}
|
||||
|
|
|
|||
|
|
@ -128,7 +128,7 @@ class KnowledgeBase(BaseModel):
|
|||
is_shared: bool
|
||||
source_meeting_ids: Optional[str] = None
|
||||
user_prompt: Optional[str] = None
|
||||
tags: Optional[List[Tag]] = []
|
||||
tags: Union[Optional[str], Optional[List[Tag]]] = None # 支持字符串或Tag列表
|
||||
created_at: datetime.datetime
|
||||
updated_at: datetime.datetime
|
||||
source_meeting_count: Optional[int] = 0
|
||||
|
|
@ -204,3 +204,26 @@ class UpdateClientDownloadRequest(BaseModel):
|
|||
class ClientDownloadListResponse(BaseModel):
|
||||
clients: List[ClientDownload]
|
||||
total: int
|
||||
|
||||
# 声纹采集相关模型
|
||||
class VoiceprintInfo(BaseModel):
|
||||
vp_id: int
|
||||
user_id: int
|
||||
file_path: str
|
||||
file_size: Optional[int] = None
|
||||
duration_seconds: Optional[float] = None
|
||||
collected_at: datetime.datetime
|
||||
updated_at: datetime.datetime
|
||||
|
||||
class VoiceprintStatus(BaseModel):
|
||||
has_voiceprint: bool
|
||||
vp_id: Optional[int] = None
|
||||
file_path: Optional[str] = None
|
||||
duration_seconds: Optional[float] = None
|
||||
collected_at: Optional[datetime.datetime] = None
|
||||
|
||||
class VoiceprintTemplate(BaseModel):
|
||||
template_text: str
|
||||
duration_seconds: int
|
||||
sample_rate: int
|
||||
channels: int
|
||||
|
|
|
|||
|
|
@ -0,0 +1,218 @@
|
|||
"""
|
||||
声纹服务 - 处理用户声纹采集、存储和验证
|
||||
"""
|
||||
import os
|
||||
import json
|
||||
import wave
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict
|
||||
from pathlib import Path
|
||||
|
||||
from app.core.database import get_db_connection
|
||||
import app.core.config as config_module
|
||||
|
||||
|
||||
class VoiceprintService:
|
||||
"""声纹服务类 - 同步处理声纹采集"""
|
||||
|
||||
def __init__(self):
|
||||
self.voiceprint_dir = config_module.VOICEPRINT_DIR
|
||||
self.voiceprint_config = config_module.VOICEPRINT_CONFIG
|
||||
|
||||
def get_user_voiceprint_status(self, user_id: int) -> Dict:
|
||||
"""
|
||||
获取用户声纹状态
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
|
||||
Returns:
|
||||
Dict: 声纹状态信息
|
||||
"""
|
||||
try:
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
query = """
|
||||
SELECT vp_id, user_id, file_path, file_size, duration_seconds, collected_at, updated_at
|
||||
FROM user_voiceprint
|
||||
WHERE user_id = %s
|
||||
"""
|
||||
cursor.execute(query, (user_id,))
|
||||
voiceprint = cursor.fetchone()
|
||||
|
||||
if voiceprint:
|
||||
return {
|
||||
"has_voiceprint": True,
|
||||
"vp_id": voiceprint['vp_id'],
|
||||
"file_path": voiceprint['file_path'],
|
||||
"duration_seconds": float(voiceprint['duration_seconds']) if voiceprint['duration_seconds'] else None,
|
||||
"collected_at": voiceprint['collected_at'].isoformat() if voiceprint['collected_at'] else None
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"has_voiceprint": False,
|
||||
"vp_id": None,
|
||||
"file_path": None,
|
||||
"duration_seconds": None,
|
||||
"collected_at": None
|
||||
}
|
||||
except Exception as e:
|
||||
print(f"获取声纹状态错误: {e}")
|
||||
raise e
|
||||
|
||||
def save_voiceprint(self, user_id: int, audio_file_path: str, file_size: int) -> Dict:
|
||||
"""
|
||||
保存声纹文件并提取特征向量
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
audio_file_path: 音频文件路径
|
||||
file_size: 文件大小
|
||||
|
||||
Returns:
|
||||
Dict: 保存结果
|
||||
"""
|
||||
try:
|
||||
# 1. 获取音频时长
|
||||
duration = self._get_audio_duration(audio_file_path)
|
||||
|
||||
# 2. 提取声纹向量(调用FunASR)
|
||||
vector_data = self._extract_voiceprint_vector(audio_file_path)
|
||||
|
||||
# 3. 保存到数据库
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
|
||||
# 检查用户是否已有声纹
|
||||
cursor.execute("SELECT vp_id FROM user_voiceprint WHERE user_id = %s", (user_id,))
|
||||
existing = cursor.fetchone()
|
||||
|
||||
# 计算相对路径
|
||||
relative_path = str(Path(audio_file_path).relative_to(config_module.BASE_DIR))
|
||||
|
||||
if existing:
|
||||
# 更新现有记录
|
||||
update_query = """
|
||||
UPDATE user_voiceprint
|
||||
SET file_path = %s, file_size = %s, duration_seconds = %s,
|
||||
vector_data = %s, updated_at = NOW()
|
||||
WHERE user_id = %s
|
||||
"""
|
||||
cursor.execute(update_query, (
|
||||
relative_path, file_size, duration,
|
||||
json.dumps(vector_data) if vector_data else None,
|
||||
user_id
|
||||
))
|
||||
vp_id = existing['vp_id']
|
||||
else:
|
||||
# 插入新记录
|
||||
insert_query = """
|
||||
INSERT INTO user_voiceprint
|
||||
(user_id, file_path, file_size, duration_seconds, vector_data, collected_at, updated_at)
|
||||
VALUES (%s, %s, %s, %s, %s, NOW(), NOW())
|
||||
"""
|
||||
cursor.execute(insert_query, (
|
||||
user_id, relative_path, file_size, duration,
|
||||
json.dumps(vector_data) if vector_data else None
|
||||
))
|
||||
vp_id = cursor.lastrowid
|
||||
|
||||
connection.commit()
|
||||
|
||||
return {
|
||||
"vp_id": vp_id,
|
||||
"user_id": user_id,
|
||||
"file_path": relative_path,
|
||||
"file_size": file_size,
|
||||
"duration_seconds": duration,
|
||||
"has_vector": vector_data is not None
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
print(f"保存声纹错误: {e}")
|
||||
raise e
|
||||
|
||||
def delete_voiceprint(self, user_id: int) -> bool:
|
||||
"""
|
||||
删除用户声纹
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
|
||||
Returns:
|
||||
bool: 是否删除成功
|
||||
"""
|
||||
try:
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
|
||||
# 获取文件路径
|
||||
cursor.execute("SELECT file_path FROM user_voiceprint WHERE user_id = %s", (user_id,))
|
||||
voiceprint = cursor.fetchone()
|
||||
|
||||
if voiceprint:
|
||||
# 构建完整文件路径
|
||||
relative_path = voiceprint['file_path']
|
||||
if relative_path.startswith('/'):
|
||||
relative_path = relative_path.lstrip('/')
|
||||
file_path = config_module.BASE_DIR / relative_path
|
||||
|
||||
# 删除数据库记录
|
||||
cursor.execute("DELETE FROM user_voiceprint WHERE user_id = %s", (user_id,))
|
||||
connection.commit()
|
||||
|
||||
# 删除文件
|
||||
if file_path.exists():
|
||||
os.remove(file_path)
|
||||
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"删除声纹错误: {e}")
|
||||
raise e
|
||||
|
||||
def _get_audio_duration(self, audio_file_path: str) -> float:
|
||||
"""
|
||||
获取音频文件时长
|
||||
|
||||
Args:
|
||||
audio_file_path: 音频文件路径
|
||||
|
||||
Returns:
|
||||
float: 时长(秒)
|
||||
"""
|
||||
try:
|
||||
with wave.open(audio_file_path, 'rb') as wav_file:
|
||||
frames = wav_file.getnframes()
|
||||
rate = wav_file.getframerate()
|
||||
duration = frames / float(rate)
|
||||
return round(duration, 2)
|
||||
except Exception as e:
|
||||
print(f"获取音频时长错误: {e}")
|
||||
return 10.0 # 默认返回10秒
|
||||
|
||||
def _extract_voiceprint_vector(self, audio_file_path: str) -> Optional[list]:
|
||||
"""
|
||||
提取声纹特征向量(调用FunASR)
|
||||
|
||||
Args:
|
||||
audio_file_path: 音频文件路径
|
||||
|
||||
Returns:
|
||||
Optional[list]: 声纹向量(192维),失败返回None
|
||||
"""
|
||||
# TODO: 集成FunASR的说话人识别模型
|
||||
# 使用 speech_campplus_sv_zh-cn_16k-common 模型
|
||||
# 返回192维的embedding向量
|
||||
|
||||
print(f"[TODO] 调用FunASR提取声纹向量: {audio_file_path}")
|
||||
|
||||
# 暂时返回None,等待FunASR集成
|
||||
# 集成后应该返回类似: [0.123, -0.456, 0.789, ...]
|
||||
return None
|
||||
|
||||
|
||||
# 创建全局实例
|
||||
voiceprint_service = VoiceprintService()
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
{
|
||||
"model_name": "qwen-plus",
|
||||
"system_prompt": "你是一个专业的会议记录分析助手。请根据提供的会议转录内容,生成简洁明了的会议总结。\n\n总结包括五个部分(名称严格一致,生成为MD二级目录):\n1. 会议概述 - 简要说明会议的主要目的和背景(生成MD引用)\n2. 主要讨论点 - 列出会议中讨论的重要话题和内容\n3. 决策事项 - 明确记录会议中做出的决定和结论\n4. 待办事项 - 列出需要后续跟进的任务和责任人\n5. 关键信息 - 其他重要的信息点\n\n输出要求:\n- 保持客观中性,不添加个人观点\n- 使用简洁、准确的中文表达\n- 按重要性排序各项内容\n- 如果某个部分没有相关内容,可以说明\"无相关内容\"\n- 总字数控制在500字以内",
|
||||
"DEFAULT_RESET_PASSWORD": "111111",
|
||||
"MAX_FILE_SIZE": 209715200,
|
||||
"DEFAULT_RESET_PASSWORD": "123456",
|
||||
"MAX_FILE_SIZE": 208666624,
|
||||
"MAX_IMAGE_SIZE": 10485760
|
||||
}
|
||||
3
main.py
3
main.py
|
|
@ -2,7 +2,7 @@ import uvicorn
|
|||
from fastapi import FastAPI, Request, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from app.api.endpoints import auth, users, meetings, tags, admin, tasks, prompts, knowledge_base, client_downloads
|
||||
from app.api.endpoints import auth, users, meetings, tags, admin, tasks, prompts, knowledge_base, client_downloads, voiceprint
|
||||
from app.core.config import UPLOAD_DIR, API_CONFIG
|
||||
from app.api.endpoints.admin import load_system_config
|
||||
import os
|
||||
|
|
@ -39,6 +39,7 @@ app.include_router(tasks.router, prefix="/api", tags=["Tasks"])
|
|||
app.include_router(prompts.router, prefix="/api", tags=["Prompts"])
|
||||
app.include_router(knowledge_base.router, prefix="/api", tags=["KnowledgeBase"])
|
||||
app.include_router(client_downloads.router, prefix="/api/clients", tags=["ClientDownloads"])
|
||||
app.include_router(voiceprint.router, prefix="/api", tags=["Voiceprint"])
|
||||
|
||||
@app.get("/")
|
||||
def read_root():
|
||||
|
|
|
|||
|
|
@ -1,17 +0,0 @@
|
|||
-- 为meetings表添加updated_at和user_prompt字段
|
||||
-- 执行日期: 2025-10-28
|
||||
|
||||
-- 添加updated_at字段
|
||||
ALTER TABLE meetings
|
||||
ADD COLUMN updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
|
||||
AFTER created_at;
|
||||
|
||||
-- 添加user_prompt字段
|
||||
ALTER TABLE meetings
|
||||
ADD COLUMN user_prompt TEXT
|
||||
AFTER summary;
|
||||
|
||||
-- 为现有记录设置updated_at为created_at的值
|
||||
UPDATE meetings
|
||||
SET updated_at = created_at
|
||||
WHERE updated_at IS NULL;
|
||||
|
|
@ -0,0 +1,141 @@
|
|||
"""
|
||||
声纹采集API测试脚本
|
||||
|
||||
使用方法:
|
||||
1. 确保后端服务正在运行
|
||||
2. 修改 USER_ID 和 TOKEN 为实际值
|
||||
3. 准备一个10秒的WAV音频文件
|
||||
4. 运行: python test_voiceprint_api.py
|
||||
"""
|
||||
|
||||
import requests
|
||||
import json
|
||||
|
||||
# 配置
|
||||
BASE_URL = "http://localhost:8000/api"
|
||||
USER_ID = 1 # 修改为实际用户ID
|
||||
TOKEN = "" # 登录后获取的token
|
||||
|
||||
# 请求头
|
||||
headers = {
|
||||
"Authorization": f"Bearer {TOKEN}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
|
||||
def test_get_template():
|
||||
"""测试获取朗读模板"""
|
||||
print("\n=== 测试1: 获取朗读模板 ===")
|
||||
url = f"{BASE_URL}/voiceprint/template"
|
||||
response = requests.get(url, headers=headers)
|
||||
print(f"状态码: {response.status_code}")
|
||||
print(f"响应: {json.dumps(response.json(), ensure_ascii=False, indent=2)}")
|
||||
return response.json()
|
||||
|
||||
|
||||
def test_get_status(user_id):
|
||||
"""测试获取声纹状态"""
|
||||
print(f"\n=== 测试2: 获取用户 {user_id} 的声纹状态 ===")
|
||||
url = f"{BASE_URL}/voiceprint/{user_id}"
|
||||
response = requests.get(url, headers=headers)
|
||||
print(f"状态码: {response.status_code}")
|
||||
print(f"响应: {json.dumps(response.json(), ensure_ascii=False, indent=2)}")
|
||||
return response.json()
|
||||
|
||||
|
||||
def test_upload_voiceprint(user_id, audio_file_path):
|
||||
"""测试上传声纹"""
|
||||
print(f"\n=== 测试3: 上传声纹音频 ===")
|
||||
url = f"{BASE_URL}/voiceprint/{user_id}"
|
||||
|
||||
# 移除Content-Type,让requests自动设置multipart/form-data
|
||||
upload_headers = {
|
||||
"Authorization": f"Bearer {TOKEN}"
|
||||
}
|
||||
|
||||
with open(audio_file_path, 'rb') as f:
|
||||
files = {'audio_file': (audio_file_path.split('/')[-1], f, 'audio/wav')}
|
||||
response = requests.post(url, headers=upload_headers, files=files)
|
||||
|
||||
print(f"状态码: {response.status_code}")
|
||||
print(f"响应: {json.dumps(response.json(), ensure_ascii=False, indent=2)}")
|
||||
return response.json()
|
||||
|
||||
|
||||
def test_delete_voiceprint(user_id):
|
||||
"""测试删除声纹"""
|
||||
print(f"\n=== 测试4: 删除用户 {user_id} 的声纹 ===")
|
||||
url = f"{BASE_URL}/voiceprint/{user_id}"
|
||||
response = requests.delete(url, headers=headers)
|
||||
print(f"状态码: {response.status_code}")
|
||||
print(f"响应: {json.dumps(response.json(), ensure_ascii=False, indent=2)}")
|
||||
return response.json()
|
||||
|
||||
|
||||
def login(username, password):
|
||||
"""登录获取token"""
|
||||
print("\n=== 登录获取Token ===")
|
||||
url = f"{BASE_URL}/auth/login"
|
||||
data = {
|
||||
"username": username,
|
||||
"password": password
|
||||
}
|
||||
response = requests.post(url, json=data)
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
if result.get('code') == '200':
|
||||
token = result['data']['token']
|
||||
print(f"登录成功,Token: {token[:20]}...")
|
||||
return token
|
||||
else:
|
||||
print(f"登录失败: {result.get('message')}")
|
||||
return None
|
||||
else:
|
||||
print(f"请求失败,状态码: {response.status_code}")
|
||||
return None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("=" * 60)
|
||||
print("声纹采集API测试脚本")
|
||||
print("=" * 60)
|
||||
|
||||
# 步骤1: 登录(如果没有token)
|
||||
if not TOKEN:
|
||||
print("\n请先登录获取Token...")
|
||||
username = input("用户名: ")
|
||||
password = input("密码: ")
|
||||
TOKEN = login(username, password)
|
||||
if TOKEN:
|
||||
headers["Authorization"] = f"Bearer {TOKEN}"
|
||||
else:
|
||||
print("登录失败,退出测试")
|
||||
exit(1)
|
||||
|
||||
# 步骤2: 测试获取朗读模板
|
||||
test_get_template()
|
||||
|
||||
# 步骤3: 测试获取声纹状态
|
||||
test_get_status(USER_ID)
|
||||
|
||||
# 步骤4: 测试上传声纹(需要准备音频文件)
|
||||
audio_file = input("\n请输入WAV音频文件路径 (回车跳过上传测试): ")
|
||||
if audio_file.strip():
|
||||
test_upload_voiceprint(USER_ID, audio_file.strip())
|
||||
|
||||
# 上传后再次查看状态
|
||||
print("\n=== 上传后再次查看状态 ===")
|
||||
test_get_status(USER_ID)
|
||||
|
||||
# 步骤5: 测试删除声纹
|
||||
confirm = input("\n是否测试删除声纹? (yes/no): ")
|
||||
if confirm.lower() == 'yes':
|
||||
test_delete_voiceprint(USER_ID)
|
||||
|
||||
# 删除后再次查看状态
|
||||
print("\n=== 删除后再次查看状态 ===")
|
||||
test_get_status(USER_ID)
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("测试完成")
|
||||
print("=" * 60)
|
||||
Loading…
Reference in New Issue