imetting_backend/app/services/voiceprint_service.py

219 lines
7.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""
声纹服务 - 处理用户声纹采集、存储和验证
"""
import os
import json
import wave
from datetime import datetime
from typing import Optional, Dict
from pathlib import Path
from app.core.database import get_db_connection
import app.core.config as config_module
class VoiceprintService:
    """Synchronous service for enrolling, looking up and deleting user voiceprints.

    Each method opens its own database connection through
    ``get_db_connection()`` and works against the ``user_voiceprint`` table.
    Audio files are referenced by a path stored relative to
    ``config_module.BASE_DIR``.
    """

    def __init__(self):
        # Both values are read from the project configuration module.
        self.voiceprint_dir = config_module.VOICEPRINT_DIR
        self.voiceprint_config = config_module.VOICEPRINT_CONFIG

    def get_user_voiceprint_status(self, user_id: int) -> Dict:
        """Return the enrollment status of a user's voiceprint.

        Args:
            user_id: ID of the user to look up.

        Returns:
            Dict with keys ``has_voiceprint``, ``vp_id``, ``file_path``,
            ``duration_seconds`` and ``collected_at``. When no row exists,
            ``has_voiceprint`` is False and every other value is None.

        Raises:
            Exception: any error from the database layer is printed and
                re-raised unchanged.
        """
        try:
            with get_db_connection() as connection:
                cursor = connection.cursor(dictionary=True)
                query = """
                    SELECT vp_id, user_id, file_path, file_size, duration_seconds, collected_at, updated_at
                    FROM user_voiceprint
                    WHERE user_id = %s
                """
                cursor.execute(query, (user_id,))
                voiceprint = cursor.fetchone()

                if not voiceprint:
                    return {
                        "has_voiceprint": False,
                        "vp_id": None,
                        "file_path": None,
                        "duration_seconds": None,
                        "collected_at": None,
                    }

                duration = voiceprint['duration_seconds']
                collected_at = voiceprint['collected_at']
                return {
                    "has_voiceprint": True,
                    "vp_id": voiceprint['vp_id'],
                    "file_path": voiceprint['file_path'],
                    # Compare against None so a stored duration of 0 is
                    # reported as 0.0 instead of being mistaken for "missing".
                    "duration_seconds": float(duration) if duration is not None else None,
                    "collected_at": collected_at.isoformat() if collected_at else None,
                }
        except Exception as e:
            print(f"获取声纹状态错误: {e}")
            raise

    def save_voiceprint(self, user_id: int, audio_file_path: str, file_size: int) -> Dict:
        """Persist a voiceprint recording for a user, inserting or updating its row.

        Args:
            user_id: ID of the user the recording belongs to.
            audio_file_path: Path of the audio file; it must be located under
                ``config_module.BASE_DIR``.
            file_size: Size of the audio file, stored as given.

        Returns:
            Dict with ``vp_id``, ``user_id``, ``file_path`` (relative to
            ``BASE_DIR``), ``file_size``, ``duration_seconds`` and
            ``has_vector`` (True only when a vector was actually stored).

        Raises:
            ValueError: if ``audio_file_path`` is not under ``BASE_DIR``.
            Exception: any other error is printed and re-raised unchanged.
        """
        try:
            # Resolve the stored path first: an invalid location should fail
            # before feature extraction runs and before a connection is opened.
            relative_path = str(Path(audio_file_path).relative_to(config_module.BASE_DIR))

            # 1. Audio duration.
            duration = self._get_audio_duration(audio_file_path)
            # 2. Voiceprint embedding (FunASR integration is still a TODO).
            vector_data = self._extract_voiceprint_vector(audio_file_path)
            vector_json = self._serialize_vector(vector_data)

            # 3. Persist to the database.
            with get_db_connection() as connection:
                cursor = connection.cursor(dictionary=True)

                # Does the user already have a voiceprint row?
                cursor.execute("SELECT vp_id FROM user_voiceprint WHERE user_id = %s", (user_id,))
                existing = cursor.fetchone()

                if existing:
                    # Update the existing record in place.
                    update_query = """
                        UPDATE user_voiceprint
                        SET file_path = %s, file_size = %s, duration_seconds = %s,
                        vector_data = %s, updated_at = NOW()
                        WHERE user_id = %s
                    """
                    cursor.execute(update_query, (
                        relative_path, file_size, duration,
                        vector_json,
                        user_id,
                    ))
                    vp_id = existing['vp_id']
                else:
                    # Insert a new record.
                    insert_query = """
                        INSERT INTO user_voiceprint
                        (user_id, file_path, file_size, duration_seconds, vector_data, collected_at, updated_at)
                        VALUES (%s, %s, %s, %s, %s, NOW(), NOW())
                    """
                    cursor.execute(insert_query, (
                        user_id, relative_path, file_size, duration,
                        vector_json,
                    ))
                    vp_id = cursor.lastrowid

                connection.commit()

                return {
                    "vp_id": vp_id,
                    "user_id": user_id,
                    "file_path": relative_path,
                    "file_size": file_size,
                    "duration_seconds": duration,
                    # Mirrors what was written to vector_data, so an empty
                    # vector is never reported as present.
                    "has_vector": vector_json is not None,
                }
        except Exception as e:
            print(f"保存声纹错误: {e}")
            raise

    def delete_voiceprint(self, user_id: int) -> bool:
        """Delete a user's voiceprint row and its audio file.

        The database row is deleted and committed first; the file is removed
        afterwards.

        Args:
            user_id: ID of the user whose voiceprint should be removed.

        Returns:
            True if a row existed and was deleted, False if the user had no
            voiceprint.

        Raises:
            Exception: database errors and file-removal errors other than
                "file not found" are printed and re-raised unchanged.
        """
        try:
            with get_db_connection() as connection:
                cursor = connection.cursor(dictionary=True)

                # Look up the stored file path.
                cursor.execute("SELECT file_path FROM user_voiceprint WHERE user_id = %s", (user_id,))
                voiceprint = cursor.fetchone()
                if not voiceprint:
                    return False

                # Stored paths are relative to BASE_DIR; drop leading slashes
                # so the join below cannot be replaced by an absolute path.
                # A NULL/empty path is tolerated: the row is still deleted.
                relative_path = (voiceprint['file_path'] or '').lstrip('/')

                # Delete the database record.
                cursor.execute("DELETE FROM user_voiceprint WHERE user_id = %s", (user_id,))
                connection.commit()

                # Delete the file. Skipped for an empty path, which would
                # otherwise resolve to BASE_DIR itself.
                if relative_path:
                    file_path = config_module.BASE_DIR / relative_path
                    try:
                        os.remove(file_path)
                    except FileNotFoundError:
                        # Already gone: nothing left to clean up.
                        pass
                return True
        except Exception as e:
            print(f"删除声纹错误: {e}")
            raise

    def _get_audio_duration(self, audio_file_path: str) -> float:
        """Return the duration of a WAV file in seconds, rounded to 2 decimals.

        Args:
            audio_file_path: Path of the audio file; read with the stdlib
                ``wave`` module.

        Returns:
            The duration in seconds, or the placeholder value 10.0 when the
            file cannot be read (the error is printed, not raised).
        """
        try:
            with wave.open(audio_file_path, 'rb') as wav_file:
                frames = wav_file.getnframes()
                rate = wav_file.getframerate()
                duration = frames / float(rate)
                return round(duration, 2)
        except Exception as e:
            # Deliberate best-effort: an unreadable file must not block enrollment.
            print(f"获取音频时长错误: {e}")
            return 10.0  # default fallback of 10 seconds

    def _extract_voiceprint_vector(self, audio_file_path: str) -> Optional[list]:
        """Extract the voiceprint embedding for an audio file (not implemented yet).

        Args:
            audio_file_path: Path of the audio file.

        Returns:
            Currently always None. Once FunASR is integrated this is meant to
            return the embedding as a list of floats, e.g. [0.123, -0.456, ...].
        """
        # TODO: integrate the FunASR speaker-verification model
        # (speech_campplus_sv_zh-cn_16k-common) and return its
        # 192-dimensional embedding vector.
        print(f"[TODO] 调用FunASR提取声纹向量: {audio_file_path}")
        # Return None until the FunASR integration lands.
        return None

    @staticmethod
    def _serialize_vector(vector_data: Optional[list]) -> Optional[str]:
        """Return the JSON text stored in ``vector_data``, or None for a missing/empty vector."""
        return json.dumps(vector_data) if vector_data else None
# Module-level shared instance, created when this module is imported.
voiceprint_service = VoiceprintService()