""" 声纹服务 - 处理用户声纹采集、存储和验证 """ import os import json import wave from datetime import datetime from typing import Optional, Dict from pathlib import Path from app.core.database import get_db_connection import app.core.config as config_module class VoiceprintService: """声纹服务类 - 同步处理声纹采集""" def __init__(self): self.voiceprint_dir = config_module.VOICEPRINT_DIR self.voiceprint_config = config_module.VOICEPRINT_CONFIG def get_user_voiceprint_status(self, user_id: int) -> Dict: """ 获取用户声纹状态 Args: user_id: 用户ID Returns: Dict: 声纹状态信息 """ try: with get_db_connection() as connection: cursor = connection.cursor(dictionary=True) query = """ SELECT vp_id, user_id, file_path, file_size, duration_seconds, collected_at, updated_at FROM user_voiceprint WHERE user_id = %s """ cursor.execute(query, (user_id,)) voiceprint = cursor.fetchone() if voiceprint: return { "has_voiceprint": True, "vp_id": voiceprint['vp_id'], "file_path": voiceprint['file_path'], "duration_seconds": float(voiceprint['duration_seconds']) if voiceprint['duration_seconds'] else None, "collected_at": voiceprint['collected_at'].isoformat() if voiceprint['collected_at'] else None } else: return { "has_voiceprint": False, "vp_id": None, "file_path": None, "duration_seconds": None, "collected_at": None } except Exception as e: print(f"获取声纹状态错误: {e}") raise e def save_voiceprint(self, user_id: int, audio_file_path: str, file_size: int) -> Dict: """ 保存声纹文件并提取特征向量 Args: user_id: 用户ID audio_file_path: 音频文件路径 file_size: 文件大小 Returns: Dict: 保存结果 """ try: # 1. 获取音频时长 duration = self._get_audio_duration(audio_file_path) # 2. 提取声纹向量(调用FunASR) vector_data = self._extract_voiceprint_vector(audio_file_path) # 3. 保存到数据库 with get_db_connection() as connection: cursor = connection.cursor(dictionary=True) # 检查用户是否已有声纹 cursor.execute("SELECT vp_id FROM user_voiceprint WHERE user_id = %s", (user_id,)) existing = cursor.fetchone() # 计算相对路径 relative_path = str(Path(audio_file_path).relative_to(config_module.BASE_DIR)) if existing: # 更新现有记录 update_query = """ UPDATE user_voiceprint SET file_path = %s, file_size = %s, duration_seconds = %s, vector_data = %s, updated_at = NOW() WHERE user_id = %s """ cursor.execute(update_query, ( relative_path, file_size, duration, json.dumps(vector_data) if vector_data else None, user_id )) vp_id = existing['vp_id'] else: # 插入新记录 insert_query = """ INSERT INTO user_voiceprint (user_id, file_path, file_size, duration_seconds, vector_data, collected_at, updated_at) VALUES (%s, %s, %s, %s, %s, NOW(), NOW()) """ cursor.execute(insert_query, ( user_id, relative_path, file_size, duration, json.dumps(vector_data) if vector_data else None )) vp_id = cursor.lastrowid connection.commit() return { "vp_id": vp_id, "user_id": user_id, "file_path": relative_path, "file_size": file_size, "duration_seconds": duration, "has_vector": vector_data is not None } except Exception as e: print(f"保存声纹错误: {e}") raise e def delete_voiceprint(self, user_id: int) -> bool: """ 删除用户声纹 Args: user_id: 用户ID Returns: bool: 是否删除成功 """ try: with get_db_connection() as connection: cursor = connection.cursor(dictionary=True) # 获取文件路径 cursor.execute("SELECT file_path FROM user_voiceprint WHERE user_id = %s", (user_id,)) voiceprint = cursor.fetchone() if voiceprint: # 构建完整文件路径 relative_path = voiceprint['file_path'] if relative_path.startswith('/'): relative_path = relative_path.lstrip('/') file_path = config_module.BASE_DIR / relative_path # 删除数据库记录 cursor.execute("DELETE FROM user_voiceprint WHERE user_id = %s", (user_id,)) connection.commit() # 删除文件 if file_path.exists(): os.remove(file_path) return True else: return False except Exception as e: print(f"删除声纹错误: {e}") raise e def _get_audio_duration(self, audio_file_path: str) -> float: """ 获取音频文件时长 Args: audio_file_path: 音频文件路径 Returns: float: 时长(秒) """ try: with wave.open(audio_file_path, 'rb') as wav_file: frames = wav_file.getnframes() rate = wav_file.getframerate() duration = frames / float(rate) return round(duration, 2) except Exception as e: print(f"获取音频时长错误: {e}") return 10.0 # 默认返回10秒 def _extract_voiceprint_vector(self, audio_file_path: str) -> Optional[list]: """ 提取声纹特征向量(调用FunASR) Args: audio_file_path: 音频文件路径 Returns: Optional[list]: 声纹向量(192维),失败返回None """ # TODO: 集成FunASR的说话人识别模型 # 使用 speech_campplus_sv_zh-cn_16k-common 模型 # 返回192维的embedding向量 print(f"[TODO] 调用FunASR提取声纹向量: {audio_file_path}") # 暂时返回None,等待FunASR集成 # 集成后应该返回类似: [0.123, -0.456, 0.789, ...] return None # 创建全局实例 voiceprint_service = VoiceprintService()