411 lines
16 KiB
Python
411 lines
16 KiB
Python
import uuid
|
||
import json
|
||
import redis
|
||
import requests
|
||
from datetime import datetime
|
||
from typing import Optional, Dict, Any
|
||
from http import HTTPStatus
|
||
|
||
import dashscope
|
||
from dashscope.audio.asr import Transcription
|
||
|
||
from app.core.config import QWEN_API_KEY, REDIS_CONFIG, APP_CONFIG
|
||
from app.core.database import get_db_connection
|
||
|
||
|
||
class AsyncTranscriptionService:
    """Asynchronous transcription service built on the DashScope Paraformer API.

    Task state lives in two places:

    * a Redis hash ``task:<business_task_id>`` used as a cache with a TTL, and
    * the ``transcript_tasks`` table, which is the durable record.

    Status lookups read Redis first and fall back to the database.
    """

    # Lifetime of the Redis task hash, in seconds (24 hours).
    TASK_CACHE_TTL_SECONDS = 86400
    # Timeout (seconds) for downloading the transcription JSON. Without a
    # timeout, a stalled download would block the polling caller forever.
    RESULT_FETCH_TIMEOUT_SECONDS = 30

    def __init__(self):
        dashscope.api_key = QWEN_API_KEY
        self.redis_client = redis.Redis(**REDIS_CONFIG)
        self.base_url = APP_CONFIG['base_url']

    def start_transcription(self, meeting_id: int, audio_file_path: str) -> str:
        """Start an asynchronous Paraformer transcription task.

        Args:
            meeting_id: Meeting ID.
            audio_file_path: Audio file path, appended verbatim to ``base_url``.

        Returns:
            str: The business task ID (a UUID4 string).

        Raises:
            Exception: If the provider rejects the request, or if the task
                record cannot be stored.
        """
        try:
            # Build the full file URL that the provider will download.
            file_url = f"{self.base_url}{audio_file_path}"

            print(f"Starting transcription for meeting_id: {meeting_id}, file_url: {file_url}")

            # Submit the job to the Paraformer async API.
            task_response = Transcription.async_call(
                model='paraformer-v2',
                file_urls=[file_url],
                language_hints=['zh', 'en'],
                disfluency_removal_enabled=True,
                diarization_enabled=True,
                speaker_count=10
            )

            if task_response.status_code != HTTPStatus.OK:
                print(f"Failed to start transcription: {task_response.status_code}, {task_response.message}")
                raise Exception(f"Transcription API error: {task_response.message}")

            paraformer_task_id = task_response.output.task_id
            business_task_id = str(uuid.uuid4())

            # Write the durable DB record first. If it fails, no Redis entry is
            # left behind for a task whose creation the caller sees as failed.
            self._save_task_to_db(business_task_id, paraformer_task_id, meeting_id, audio_file_path)

            # Cache the business-task -> provider-task mapping in Redis.
            current_time = datetime.now().isoformat()
            task_data = {
                'business_task_id': business_task_id,
                'paraformer_task_id': paraformer_task_id,
                'meeting_id': str(meeting_id),
                'file_url': file_url,
                'status': 'pending',
                'progress': '0',
                'created_at': current_time,
                'updated_at': current_time
            }
            task_key = f"task:{business_task_id}"
            self.redis_client.hset(task_key, mapping=task_data)
            self.redis_client.expire(task_key, self.TASK_CACHE_TTL_SECONDS)

            print(f"Transcription task created: {business_task_id}")
            return business_task_id

        except Exception as e:
            print(f"Error starting transcription: {e}")
            raise

    def get_task_status(self, business_task_id: str) -> Dict[str, Any]:
        """Return the current status of a transcription task.

        If the task is still running, the provider is queried and the new
        status is written back to Redis and the database. When the provider
        reports success, the transcript is downloaded and its segments are
        saved. A task already recorded as ``completed`` is answered from the
        stored state, so its segments are never re-imported.

        Args:
            business_task_id: Business task ID.

        Returns:
            Dict: ``task_id``, ``status``, ``progress``, ``error_message``,
            ``updated_at``, ``meeting_id`` and ``created_at``. Errors are
            reported through ``status='failed'`` and ``error_message`` rather
            than raised.
        """
        # 1. Load the task record (Redis first, then the DB).
        try:
            task_data = self._get_task_data(business_task_id)
        except Exception as e:
            error_message = f"Error getting task status: {e}"
            print(error_message)
            # The task could not be located, so nothing is persisted. Writing
            # here would create an orphan Redis hash that has no TTL.
            return self._build_status_result(
                business_task_id, None, 'failed', 0, error_message, datetime.now().isoformat()
            )

        # Completed tasks already have their segments saved. Polling the
        # provider again would delete and re-insert every segment.
        if task_data.get('status') == 'completed':
            return self._build_status_result(
                business_task_id, task_data, 'completed', 100, None,
                task_data.get('updated_at') or datetime.now().isoformat()
            )

        # 2-3. Ask the provider, and process the result if it has finished.
        current_status, progress, error_message = self._poll_provider(business_task_id, task_data)

        # 4. Persist the new status to Redis and the database.
        updated_at = datetime.now().isoformat()
        self._persist_status(business_task_id, current_status, progress, error_message, updated_at)

        # 5. Build the response.
        return self._build_status_result(
            business_task_id, task_data, current_status, progress, error_message, updated_at
        )

    def _poll_provider(self, business_task_id: str, task_data: Dict[str, Any]) -> tuple:
        """Fetch provider status and process the transcript once it has succeeded.

        Returns:
            tuple: ``(status, progress, error_message)``. ``error_message`` is
            ``None`` on success.
        """
        try:
            paraformer_response = Transcription.fetch(task=task_data['paraformer_task_id'])
            if paraformer_response.status_code != HTTPStatus.OK:
                raise Exception(f"Failed to fetch task status from provider: {paraformer_response.message}")
            paraformer_status = paraformer_response.output.task_status
        except Exception as e:
            # NOTE(review): this is persisted as 'failed' even if the error was
            # transient; get_meeting_transcription_status stops polling such tasks.
            error_message = f"Error fetching status from provider: {e}"
            print(error_message)
            return 'failed', 0, error_message

        current_status = self._map_paraformer_status(paraformer_status)
        progress = self._calculate_progress(paraformer_status)

        if current_status != 'completed':
            return current_status, progress, None

        # Success with missing results is reported as a failure by
        # _process_transcription_result instead of a silent "completed".
        try:
            self._process_transcription_result(
                business_task_id,
                int(task_data['meeting_id']),
                paraformer_response.output
            )
        except Exception as e:
            error_message = f"Error processing transcription result: {e}"
            print(error_message)
            # The provider finished (progress 100), but the local import failed.
            return 'failed', 100, error_message

        return current_status, progress, None

    def _persist_status(self, business_task_id: str, status: str, progress: int,
                        error_message: Optional[str], updated_at: str) -> None:
        """Write the status fields to the Redis hash and the DB row."""
        task_key = f"task:{business_task_id}"
        update_data = {
            'status': status,
            'progress': str(progress),
            'updated_at': updated_at
        }
        if error_message:
            update_data['error_message'] = error_message
        self.redis_client.hset(task_key, mapping=update_data)
        if not error_message:
            # Clear any error left over from an earlier failed check.
            self.redis_client.hdel(task_key, 'error_message')

        self._update_task_status_in_db(business_task_id, status, progress, error_message)

    @staticmethod
    def _build_status_result(business_task_id: str, task_data: Optional[Dict[str, Any]],
                             status: str, progress: int, error_message: Optional[str],
                             updated_at: Optional[str]) -> Dict[str, Any]:
        """Build the status dict returned by get_task_status."""
        result = {
            'task_id': business_task_id,
            'status': status,
            'progress': progress,
            'error_message': error_message,
            'updated_at': updated_at,
            'meeting_id': None,
            'created_at': None,
        }
        if task_data:
            result['meeting_id'] = int(task_data['meeting_id'])
            result['created_at'] = task_data.get('created_at')
        return result

    @staticmethod
    def _to_str(value: Any) -> Any:
        """Decode ``bytes`` (a Redis client without decode_responses) to ``str``; pass other values through."""
        return value.decode('utf-8') if isinstance(value, bytes) else value

    def _get_task_data(self, business_task_id: str) -> Dict[str, Any]:
        """Load task data from Redis, or from the database on a cache miss.

        On a database hit, the row is written back to Redis with a TTL.

        Raises:
            Exception: If the task is unknown or has no ``paraformer_task_id``.
        """
        task_key = f"task:{business_task_id}"

        # Try Redis first. Keys and values may be bytes or str, depending on
        # how the client is configured.
        cached = self.redis_client.hgetall(task_key)
        if cached:
            decoded = {self._to_str(k): self._to_str(v) for k, v in cached.items()}
            if decoded.get('paraformer_task_id'):
                return decoded

        # Cache miss: fall back to the database.
        task_data_from_db = self._get_task_from_db(business_task_id)
        if not task_data_from_db or not task_data_from_db.get('paraformer_task_id'):
            raise Exception("Task not found in DB or paraformer_task_id is missing")

        # Re-populate the cache. redis-py rejects None values (for example a
        # NULL created_at), so those fields are left out.
        cacheable = {k: v for k, v in task_data_from_db.items() if v is not None}
        self.redis_client.hset(task_key, mapping=cacheable)
        self.redis_client.expire(task_key, self.TASK_CACHE_TTL_SECONDS)

        return task_data_from_db

    def get_meeting_transcription_status(self, meeting_id: int) -> Optional[Dict[str, Any]]:
        """Return the status of a meeting's most recent transcription task.

        Args:
            meeting_id: Meeting ID.

        Returns:
            Optional[Dict]: Task status info, or None if the meeting has no
            task or the lookup fails.
        """
        try:
            with get_db_connection() as connection:
                cursor = connection.cursor(dictionary=True)
                try:
                    # Latest transcription task for this meeting.
                    query = """
                        SELECT task_id, status, progress, created_at, completed_at, error_message
                        FROM transcript_tasks
                        WHERE meeting_id = %s
                        ORDER BY created_at DESC
                        LIMIT 1
                    """
                    cursor.execute(query, (meeting_id,))
                    task_record = cursor.fetchone()
                finally:
                    cursor.close()

            if not task_record:
                return None

            # Refresh in-flight tasks from the provider. This runs after the
            # connection above is released, because get_task_status makes
            # network calls and opens its own DB connections.
            if task_record['status'] in ['pending', 'processing']:
                try:
                    return self.get_task_status(task_record['task_id'])
                except Exception as e:
                    print(f"Failed to get latest task status for meeting {meeting_id}, returning DB status. Error: {e}")

            return {
                'task_id': task_record['task_id'],
                'status': task_record['status'],
                'progress': task_record['progress'] or 0,
                'meeting_id': meeting_id,
                'created_at': task_record['created_at'].isoformat() if task_record['created_at'] else None,
                'completed_at': task_record['completed_at'].isoformat() if task_record['completed_at'] else None,
                'error_message': task_record['error_message']
            }

        except Exception as e:
            print(f"Error getting meeting transcription status: {e}")
            return None

    def _map_paraformer_status(self, paraformer_status: str) -> str:
        """Map a Paraformer task status to a business status; unmapped values become 'unknown'."""
        status_mapping = {
            'PENDING': 'pending',
            'RUNNING': 'processing',
            'SUCCEEDED': 'completed',
            'FAILED': 'failed'
        }
        return status_mapping.get(paraformer_status, 'unknown')

    def _calculate_progress(self, paraformer_status: str) -> int:
        """Return a coarse progress percentage for a Paraformer status (0 if unmapped)."""
        progress_mapping = {
            'PENDING': 10,
            'RUNNING': 50,
            'SUCCEEDED': 100,
            'FAILED': 0
        }
        return progress_mapping.get(paraformer_status, 0)

    def _save_task_to_db(self, business_task_id: str, paraformer_task_id: str, meeting_id: int, audio_file_path: str):
        """Insert a new 'pending' row into transcript_tasks.

        ``audio_file_path`` is accepted but not stored.

        Raises:
            Exception: Re-raised from the database layer.
        """
        try:
            with get_db_connection() as connection:
                cursor = connection.cursor()
                try:
                    insert_task_query = """
                        INSERT INTO transcript_tasks (task_id, paraformer_task_id, meeting_id, status, progress, created_at)
                        VALUES (%s, %s, %s, 'pending', 0, NOW())
                    """
                    cursor.execute(insert_task_query, (business_task_id, paraformer_task_id, meeting_id))
                    connection.commit()
                finally:
                    cursor.close()

        except Exception as e:
            print(f"Error saving task to database: {e}")
            raise

    def _update_task_status_in_db(self, business_task_id: str, status: str, progress: int, error_message: Optional[str] = None):
        """Update status, progress and error_message for a task row.

        ``completed_at`` is also set to NOW() when ``status`` is 'completed'.
        This is best-effort: errors are printed, not raised.
        """
        try:
            with get_db_connection() as connection:
                cursor = connection.cursor()
                try:
                    params = (status, progress, error_message, business_task_id)
                    if status == 'completed':
                        update_query = """
                            UPDATE transcript_tasks
                            SET status = %s, progress = %s, completed_at = NOW(), error_message = %s
                            WHERE task_id = %s
                        """
                    else:
                        update_query = """
                            UPDATE transcript_tasks
                            SET status = %s, progress = %s, error_message = %s
                            WHERE task_id = %s
                        """
                    cursor.execute(update_query, params)
                    connection.commit()
                finally:
                    cursor.close()

        except Exception as e:
            print(f"Error updating task status in database: {e}")

    def _get_task_from_db(self, business_task_id: str) -> Optional[Dict[str, str]]:
        """Load a task row as a dict shaped like the Redis hash.

        ``meeting_id`` is a str and ``created_at`` is an ISO string or None.
        Returns None if the task is not found or the query fails.
        """
        try:
            with get_db_connection() as connection:
                cursor = connection.cursor(dictionary=True)
                try:
                    query = """
                        SELECT tt.task_id as business_task_id, tt.paraformer_task_id, tt.meeting_id, tt.status, tt.created_at
                        FROM transcript_tasks tt
                        WHERE tt.task_id = %s
                    """
                    cursor.execute(query, (business_task_id,))
                    result = cursor.fetchone()
                finally:
                    cursor.close()

            if not result:
                return None

            # Convert to the same string-valued layout the Redis cache uses.
            return {
                'business_task_id': result['business_task_id'],
                'paraformer_task_id': result['paraformer_task_id'],
                'meeting_id': str(result['meeting_id']),
                'status': result['status'],
                'created_at': result['created_at'].isoformat() if result['created_at'] else None
            }

        except Exception as e:
            print(f"Error getting task from database: {e}")
            return None

    def _process_transcription_result(self, business_task_id: str, meeting_id: int, paraformer_output: Any):
        """Download the transcript for the first result and save its segments.

        Raises:
            Exception: If no usable result is present, or if the download or
                the DB write fails.
        """
        try:
            results = paraformer_output.get('results')
            if not results:
                raise Exception("No transcription results found in the provider response.")

            first_result = results[0]
            transcription_url = first_result.get('transcription_url')
            if not transcription_url:
                raise Exception(f"Provider result has no transcription_url: {first_result}")
            print(f"Fetching transcription from URL: {transcription_url}")

            response = requests.get(transcription_url, timeout=self.RESULT_FETCH_TIMEOUT_SECONDS)
            response.raise_for_status()
            transcription_data = response.json()

            # Replace the meeting's segments with the new transcript.
            self._save_segments_to_db(transcription_data, meeting_id)

            print(f"Transcription result processed for task: {business_task_id}")

        except Exception as e:
            # Log with task context, then re-raise so the caller marks the task failed.
            print(f"Error processing transcription result for task {business_task_id}: {e}")
            raise

    def _save_segments_to_db(self, data: dict, meeting_id: int):
        """Replace all transcript_segments rows of a meeting with the sentences in ``data``.

        Reads ``data['transcripts'][*]['sentences'][*]``. The delete and the
        inserts are committed together with a single commit.
        """
        segments_to_insert = []
        for transcript in data.get('transcripts', []):
            for sentence in transcript.get('sentences', []):
                speaker_id = sentence.get('speaker_id', -1)
                segments_to_insert.append((
                    meeting_id,
                    speaker_id,
                    f"发言人 {speaker_id}",  # default speaker_tag ("Speaker N")
                    sentence.get('begin_time'),
                    sentence.get('end_time'),
                    sentence.get('text')
                ))

        if not segments_to_insert:
            print("No segments to save.")
            return

        try:
            with get_db_connection() as connection:
                cursor = connection.cursor()
                try:
                    # Remove the meeting's existing segments.
                    delete_query = "DELETE FROM transcript_segments WHERE meeting_id = %s"
                    cursor.execute(delete_query, (meeting_id,))
                    print(f"Deleted existing segments for meeting_id: {meeting_id}")

                    # Insert the new segments.
                    insert_query = '''
                        INSERT INTO transcript_segments (meeting_id, speaker_id, speaker_tag, start_time_ms, end_time_ms, text_content)
                        VALUES (%s, %s, %s, %s, %s, %s)
                    '''
                    cursor.executemany(insert_query, segments_to_insert)
                    connection.commit()
                finally:
                    cursor.close()
                print(f"Successfully saved {len(segments_to_insert)} segments to the database for meeting_id: {meeting_id}")

        except Exception as e:
            print(f"Database error when saving segments: {e}")
            raise