"""
Voiceprint service - handles collection, storage and verification of user voiceprints.
"""

import json
import logging
import os
import wave
from datetime import datetime
from pathlib import Path
from typing import Optional, Dict

import app.core.config as config_module
from app.core.database import get_db_connection


logger = logging.getLogger(__name__)


class VoiceprintService:
    """Voiceprint service - synchronous handling of voiceprint collection.

    Stores one voiceprint record per user in the ``user_voiceprint`` table
    (looked up by ``user_id``) and keeps the audio file path relative to
    ``config_module.BASE_DIR``.
    """

    # Duration (seconds) reported when the audio file cannot be read as WAV.
    DEFAULT_DURATION_SECONDS = 10.0

    def __init__(self):
        self.voiceprint_dir = config_module.VOICEPRINT_DIR
        self.voiceprint_config = config_module.VOICEPRINT_CONFIG

    def get_user_voiceprint_status(self, user_id: int) -> Dict:
        """Return the voiceprint status of a user.

        Args:
            user_id: User ID.

        Returns:
            Dict with keys ``has_voiceprint``, ``vp_id``, ``file_path``,
            ``duration_seconds`` and ``collected_at``. Every key except
            ``has_voiceprint`` is ``None`` when the user has no record.

        Raises:
            Exception: any database error is logged and re-raised unchanged.
        """
        query = """
            SELECT vp_id, user_id, file_path, file_size, duration_seconds, collected_at, updated_at
            FROM user_voiceprint
            WHERE user_id = %s
        """
        try:
            with get_db_connection() as connection:
                cursor = connection.cursor(dictionary=True)
                cursor.execute(query, (user_id,))
                return self._status_from_row(cursor.fetchone())
        except Exception as exc:
            logger.exception("获取声纹状态错误: %s", exc)
            raise

    @staticmethod
    def _status_from_row(row: Optional[Dict]) -> Dict:
        """Convert a ``user_voiceprint`` row (or ``None``) into the status dict."""
        if not row:
            return {
                "has_voiceprint": False,
                "vp_id": None,
                "file_path": None,
                "duration_seconds": None,
                "collected_at": None,
            }

        duration = row['duration_seconds']
        collected_at = row['collected_at']
        return {
            "has_voiceprint": True,
            "vp_id": row['vp_id'],
            "file_path": row['file_path'],
            # Compare against None explicitly: a stored duration of 0 is a
            # real value and must not be reported as "unknown".
            "duration_seconds": float(duration) if duration is not None else None,
            "collected_at": collected_at.isoformat() if collected_at is not None else None,
        }

    def save_voiceprint(self, user_id: int, audio_file_path: str, file_size: int) -> Dict:
        """Save a voiceprint file record and extract its feature vector.

        Updates the user's existing record if there is one, otherwise
        inserts a new record.

        Args:
            user_id: User ID.
            audio_file_path: Path of the audio file; must be located under
                ``config_module.BASE_DIR``.
            file_size: File size.

        Returns:
            Dict with keys ``vp_id``, ``user_id``, ``file_path`` (relative to
            ``BASE_DIR``), ``file_size``, ``duration_seconds`` and
            ``has_vector``.

        Raises:
            ValueError: if ``audio_file_path`` is not under ``BASE_DIR``.
            Exception: any other error is logged and re-raised unchanged.
        """
        try:
            # 1. Audio duration (falls back to the default when unreadable).
            duration = self._get_audio_duration(audio_file_path)

            # 2. Voiceprint vector (FunASR).
            vector_data = self._extract_voiceprint_vector(audio_file_path)
            # One flag drives both the stored value and the reported
            # ``has_vector`` so the two can never disagree (an empty vector
            # is stored as NULL and reported as "no vector").
            has_vector = bool(vector_data)
            vector_json = json.dumps(vector_data) if has_vector else None

            # Resolve the relative path before touching the database so an
            # invalid location fails without a wasted round-trip.
            relative_path = str(Path(audio_file_path).relative_to(config_module.BASE_DIR))

            # 3. Persist to the database.
            with get_db_connection() as connection:
                cursor = connection.cursor(dictionary=True)

                # Does the user already have a voiceprint?
                cursor.execute("SELECT vp_id FROM user_voiceprint WHERE user_id = %s", (user_id,))
                existing = cursor.fetchone()

                if existing:
                    # Update the existing record.
                    update_query = """
                        UPDATE user_voiceprint
                        SET file_path = %s, file_size = %s, duration_seconds = %s,
                            vector_data = %s, updated_at = NOW()
                        WHERE user_id = %s
                    """
                    cursor.execute(
                        update_query,
                        (relative_path, file_size, duration, vector_json, user_id),
                    )
                    vp_id = existing['vp_id']
                else:
                    # Insert a new record.
                    insert_query = """
                        INSERT INTO user_voiceprint
                        (user_id, file_path, file_size, duration_seconds, vector_data, collected_at, updated_at)
                        VALUES (%s, %s, %s, %s, %s, NOW(), NOW())
                    """
                    cursor.execute(
                        insert_query,
                        (user_id, relative_path, file_size, duration, vector_json),
                    )
                    vp_id = cursor.lastrowid

                connection.commit()

            return {
                "vp_id": vp_id,
                "user_id": user_id,
                "file_path": relative_path,
                "file_size": file_size,
                "duration_seconds": duration,
                "has_vector": has_vector,
            }
        except Exception as exc:
            logger.exception("保存声纹错误: %s", exc)
            raise

    def delete_voiceprint(self, user_id: int) -> bool:
        """Delete a user's voiceprint record and its audio file.

        Args:
            user_id: User ID.

        Returns:
            bool: ``True`` if a record existed and was deleted, ``False`` if
            the user had no voiceprint.

        Raises:
            Exception: database or file-system errors are logged and
                re-raised unchanged. The database row is committed before the
                file is removed, so a file-system error leaves the row deleted.
        """
        try:
            with get_db_connection() as connection:
                cursor = connection.cursor(dictionary=True)

                # Look up the stored file path.
                cursor.execute("SELECT file_path FROM user_voiceprint WHERE user_id = %s", (user_id,))
                row = cursor.fetchone()
                if not row:
                    return False

                # Delete the database record first.
                cursor.execute("DELETE FROM user_voiceprint WHERE user_id = %s", (user_id,))
                connection.commit()

                # Then delete the file. A record without a stored path has
                # nothing on disk to remove.
                stored_path = row['file_path']
                if stored_path:
                    # Leading slashes are stripped so the join stays under BASE_DIR.
                    target = config_module.BASE_DIR / stored_path.lstrip('/')
                    try:
                        os.remove(target)
                    except FileNotFoundError:
                        # Already gone - the goal state is reached.
                        pass

                return True
        except Exception as exc:
            logger.exception("删除声纹错误: %s", exc)
            raise

    def _get_audio_duration(self, audio_file_path: str) -> float:
        """Return the duration of a WAV audio file.

        Args:
            audio_file_path: Path of the audio file (``str`` or path-like).

        Returns:
            float: Duration in seconds rounded to two decimals, or
            ``DEFAULT_DURATION_SECONDS`` when the file cannot be read as WAV
            (deliberate best-effort: the failure is logged, not raised).
        """
        try:
            # Normalise path-like objects to a filename string, since not
            # every Python version's wave.open treats them as filenames.
            with wave.open(os.fspath(audio_file_path), 'rb') as wav_file:
                frame_count = wav_file.getnframes()
                frame_rate = wav_file.getframerate()
            return round(frame_count / float(frame_rate), 2)
        except Exception as exc:
            logger.warning("获取音频时长错误: %s", exc)
            return self.DEFAULT_DURATION_SECONDS

    def _extract_voiceprint_vector(self, audio_file_path: str) -> Optional[list]:
        """Extract the voiceprint feature vector (FunASR).

        Args:
            audio_file_path: Path of the audio file.

        Returns:
            Optional[list]: The voiceprint vector (192-dimensional), or
            ``None`` on failure. Currently always ``None`` because the FunASR
            integration is not implemented yet.
        """
        # TODO: integrate the FunASR speaker-verification model
        # (speech_campplus_sv_zh-cn_16k-common) and return its
        # 192-dimensional embedding, e.g. [0.123, -0.456, 0.789, ...].
        logger.warning("[TODO] 调用FunASR提取声纹向量: %s", audio_file_path)

        # Placeholder until FunASR is integrated.
        return None


# Create the module-level shared instance.
voiceprint_service = VoiceprintService()