imetting_backend/app/services/async_transcription_service.py

411 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import uuid
import json
import redis
import requests
from datetime import datetime
from typing import Optional, Dict, Any
from http import HTTPStatus
import dashscope
from dashscope.audio.asr import Transcription
from app.core.config import QWEN_API_KEY, REDIS_CONFIG, APP_CONFIG
from app.core.database import get_db_connection
class AsyncTranscriptionService:
    """Manage asynchronous Paraformer transcription tasks for meetings.

    Each task started through DashScope's ``Transcription.async_call`` is
    tracked under a locally generated "business" task id. The mapping
    between the business id and the provider id is cached in a Redis hash
    (``task:<business_task_id>``) and stored in the ``transcript_tasks``
    table; finished transcripts are written to ``transcript_segments``.
    """

    # Lifetime of a cached task hash in Redis: 24 hours.
    TASK_TTL_SECONDS = 86400

    # (connect, read) timeout in seconds used when downloading the
    # transcription JSON, so a stalled download cannot block forever.
    RESULT_FETCH_TIMEOUT = (10, 60)

    # Provider task status -> business status.
    _STATUS_MAPPING = {
        'PENDING': 'pending',
        'RUNNING': 'processing',
        'SUCCEEDED': 'completed',
        'FAILED': 'failed',
    }

    # Provider task status -> coarse progress percentage.
    _PROGRESS_MAPPING = {
        'PENDING': 10,
        'RUNNING': 50,
        'SUCCEEDED': 100,
        'FAILED': 0,
    }

    def __init__(self):
        """Configure the DashScope API key, the Redis client and the base URL."""
        dashscope.api_key = QWEN_API_KEY
        self.redis_client = redis.Redis(**REDIS_CONFIG)
        self.base_url = APP_CONFIG['base_url']

    @staticmethod
    def _task_key(business_task_id: str) -> str:
        """Return the Redis key of the hash that caches a task."""
        return f"task:{business_task_id}"

    @staticmethod
    def _decode_redis_hash(raw_hash: Optional[Dict[Any, Any]]) -> Dict[Any, Any]:
        """Return ``raw_hash`` with bytes keys/values decoded as UTF-8.

        Works whether or not the Redis client was created with
        ``decode_responses=True``: values that are already ``str`` are
        passed through unchanged. ``None`` or an empty hash yields ``{}``.
        """
        def _as_text(value):
            if isinstance(value, (bytes, bytearray)):
                return value.decode('utf-8')
            return value

        return {_as_text(k): _as_text(v) for k, v in (raw_hash or {}).items()}

    def start_transcription(self, meeting_id: int, audio_file_path: str) -> str:
        """Start an asynchronous transcription task.

        Args:
            meeting_id: Meeting id.
            audio_file_path: Audio file path relative to the configured
                base URL.

        Returns:
            str: The business task id.

        Raises:
            Exception: If the provider rejects the request, or if caching or
                saving the task record fails. The error is logged and
                re-raised unchanged.
        """
        try:
            # Build the full file URL the provider will download from.
            file_url = f"{self.base_url}{audio_file_path}"
            print(f"Starting transcription for meeting_id: {meeting_id}, file_url: {file_url}")

            # Call the Paraformer asynchronous API.
            task_response = Transcription.async_call(
                model='paraformer-v2',
                file_urls=[file_url],
                language_hints=['zh', 'en'],
                disfluency_removal_enabled=True,
                diarization_enabled=True,
                speaker_count=10
            )
            if task_response.status_code != HTTPStatus.OK:
                print(f"Failed to start transcription: {task_response.status_code}, {task_response.message}")
                raise Exception(f"Transcription API error: {task_response.message}")

            paraformer_task_id = task_response.output.task_id
            business_task_id = str(uuid.uuid4())

            # Store the task mapping in Redis (all values as strings).
            current_time = datetime.now().isoformat()
            task_data = {
                'business_task_id': business_task_id,
                'paraformer_task_id': paraformer_task_id,
                'meeting_id': str(meeting_id),
                'file_url': file_url,
                'status': 'pending',
                'progress': '0',
                'created_at': current_time,
                'updated_at': current_time
            }
            redis_key = self._task_key(business_task_id)
            self.redis_client.hset(redis_key, mapping=task_data)
            self.redis_client.expire(redis_key, self.TASK_TTL_SECONDS)

            # Create the task record in the database.
            self._save_task_to_db(business_task_id, paraformer_task_id, meeting_id, audio_file_path)
            print(f"Transcription task created: {business_task_id}")
            return business_task_id
        except Exception as e:
            print(f"Error starting transcription: {e}")
            raise

    def get_task_status(self, business_task_id: str) -> Dict[str, Any]:
        """Poll the provider for a task's status and persist the outcome.

        Args:
            business_task_id: Business task id.

        Returns:
            Dict: ``task_id``, ``status``, ``progress``, ``error_message``,
            ``updated_at``, ``meeting_id`` and ``created_at``. The last two
            are ``None`` when the task record could not be loaded. Lookup,
            provider and processing errors are reported through ``status``
            (``'failed'``) and ``error_message`` rather than raised.

        Raises:
            Exception: Only if writing the new status to Redis fails.
        """
        task_data = None
        current_status = 'failed'
        progress = 0
        error_message = "An unknown error occurred."

        try:
            # 1. Load the task data (Redis first, database as fallback).
            task_data = self._get_task_data(business_task_id)
            paraformer_task_id = task_data['paraformer_task_id']
            previous_status = task_data.get('status')

            # 2. Ask the provider for the current status.
            try:
                paraformer_response = Transcription.fetch(task=paraformer_task_id)
                if paraformer_response.status_code != HTTPStatus.OK:
                    raise Exception(f"Failed to fetch task status from provider: {paraformer_response.message}")
                paraformer_status = paraformer_response.output.task_status
                current_status = self._map_paraformer_status(paraformer_status)
                progress = self._calculate_progress(paraformer_status)
                error_message = None  # success: clear the initial placeholder
            except Exception as e:
                paraformer_response = None
                current_status = 'failed'
                progress = 0
                error_message = f"Error fetching status from provider: {e}"

            # 3. Store the transcript when the task has just completed.
            #    A task already recorded as 'completed' is not processed
            #    again, so repeated polling does not re-download the result
            #    and wipe/re-insert the meeting's segments.
            if (
                paraformer_response is not None
                and current_status == 'completed'
                and previous_status != 'completed'
                and paraformer_response.output.get('results')
            ):
                try:
                    self._process_transcription_result(
                        business_task_id,
                        int(task_data['meeting_id']),
                        paraformer_response.output
                    )
                except Exception as e:
                    current_status = 'failed'
                    progress = 100  # provider finished, but storing the result failed
                    error_message = f"Error processing transcription result: {e}"
                    print(error_message)
        except Exception as e:
            error_message = f"Error getting task status: {e}"
            print(error_message)
            current_status = 'failed'
            progress = 0

        # 4. Persist the new status. Skipped when the task record could not
        #    be loaded, so unknown ids never create TTL-less Redis hashes.
        # NOTE(review): a transient provider error is still persisted as
        # 'failed'; it is overwritten by the next successful poll.
        updated_at = datetime.now().isoformat()
        if task_data is not None:
            self._persist_status(business_task_id, current_status, progress, error_message, updated_at)

        # 5. Build and return the final result.
        result = {
            'task_id': business_task_id,
            'status': current_status,
            'progress': progress,
            'error_message': error_message,
            'updated_at': updated_at,
            'meeting_id': None,
            'created_at': None,
        }
        if task_data:
            result['meeting_id'] = int(task_data['meeting_id'])
            result['created_at'] = task_data.get('created_at')
        return result

    def _persist_status(self, business_task_id: str, status: str, progress: int,
                        error_message: Optional[str], updated_at: str) -> None:
        """Write the latest status of a task to Redis and to the database."""
        redis_key = self._task_key(business_task_id)
        update_data = {
            'status': status,
            'progress': str(progress),
            'updated_at': updated_at
        }
        if error_message:
            update_data['error_message'] = error_message
        self.redis_client.hset(redis_key, mapping=update_data)
        if not error_message:
            # Drop an error left over from an earlier failed poll.
            self.redis_client.hdel(redis_key, 'error_message')
        self._update_task_status_in_db(business_task_id, status, progress, error_message)

    def _get_task_data(self, business_task_id: str) -> Dict[str, Any]:
        """Load task data from Redis, falling back to the database.

        Raises:
            Exception: If the task is in neither store, or the stored record
                has no ``paraformer_task_id``.
        """
        redis_key = self._task_key(business_task_id)

        # Try Redis first; the client may return bytes or str.
        cached = self._decode_redis_hash(self.redis_client.hgetall(redis_key))
        if cached.get('paraformer_task_id'):
            return cached

        # Cache miss: fall back to the database.
        task_data_from_db = self._get_task_from_db(business_task_id)
        if not task_data_from_db or not task_data_from_db.get('paraformer_task_id'):
            raise Exception("Task not found in DB or paraformer_task_id is missing")

        # Cache the database record back into Redis. None values (e.g. a
        # missing created_at) are skipped because Redis cannot store them.
        cacheable = {k: v for k, v in task_data_from_db.items() if v is not None}
        self.redis_client.hset(redis_key, mapping=cacheable)
        self.redis_client.expire(redis_key, self.TASK_TTL_SECONDS)
        return task_data_from_db

    def get_meeting_transcription_status(self, meeting_id: int) -> Optional[Dict[str, Any]]:
        """Return the status of the most recent transcription task of a meeting.

        Args:
            meeting_id: Meeting id.

        Returns:
            Optional[Dict]: Task status information, or ``None`` when the
            meeting has no task or the lookup fails.
        """
        try:
            with get_db_connection() as connection:
                cursor = connection.cursor(dictionary=True)
                try:
                    # Fetch the most recently created task for the meeting.
                    query = """
                    SELECT task_id, status, progress, created_at, completed_at, error_message
                    FROM transcript_tasks
                    WHERE meeting_id = %s
                    ORDER BY created_at DESC
                    LIMIT 1
                    """
                    cursor.execute(query, (meeting_id,))
                    task_record = cursor.fetchone()
                finally:
                    cursor.close()

            if not task_record:
                return None

            # A task still in flight is refreshed from the provider; if that
            # fails, fall through to the status stored in the database.
            if task_record['status'] in ['pending', 'processing']:
                try:
                    return self.get_task_status(task_record['task_id'])
                except Exception as e:
                    print(f"Failed to get latest task status for meeting {meeting_id}, returning DB status. Error: {e}")

            return {
                'task_id': task_record['task_id'],
                'status': task_record['status'],
                'progress': task_record['progress'] or 0,
                'meeting_id': meeting_id,
                'created_at': task_record['created_at'].isoformat() if task_record['created_at'] else None,
                'completed_at': task_record['completed_at'].isoformat() if task_record['completed_at'] else None,
                'error_message': task_record['error_message']
            }
        except Exception as e:
            print(f"Error getting meeting transcription status: {e}")
            return None

    def _map_paraformer_status(self, paraformer_status: str) -> str:
        """Map a Paraformer status to a business status ('unknown' if unmapped)."""
        return self._STATUS_MAPPING.get(paraformer_status, 'unknown')

    def _calculate_progress(self, paraformer_status: str) -> int:
        """Derive a progress percentage from a Paraformer status (0 if unmapped)."""
        return self._PROGRESS_MAPPING.get(paraformer_status, 0)

    def _save_task_to_db(self, business_task_id: str, paraformer_task_id: str, meeting_id: int, audio_file_path: str):
        """Insert a new 'pending' task record into ``transcript_tasks``.

        ``audio_file_path`` is accepted for interface compatibility but is
        not stored. Database errors are logged and re-raised.
        """
        try:
            with get_db_connection() as connection:
                cursor = connection.cursor()
                try:
                    insert_task_query = """
                    INSERT INTO transcript_tasks (task_id, paraformer_task_id, meeting_id, status, progress, created_at)
                    VALUES (%s, %s, %s, 'pending', 0, NOW())
                    """
                    cursor.execute(insert_task_query, (business_task_id, paraformer_task_id, meeting_id))
                    connection.commit()
                finally:
                    cursor.close()
        except Exception as e:
            print(f"Error saving task to database: {e}")
            raise

    def _update_task_status_in_db(self, business_task_id: str, status: str, progress: int, error_message: Optional[str] = None):
        """Update a task's status in the database (best effort).

        ``completed_at`` is stamped only when ``status`` is ``'completed'``.
        Database errors are logged and deliberately not raised.
        """
        try:
            with get_db_connection() as connection:
                cursor = connection.cursor()
                try:
                    if status == 'completed':
                        update_query = """
                        UPDATE transcript_tasks
                        SET status = %s, progress = %s, completed_at = NOW(), error_message = %s
                        WHERE task_id = %s
                        """
                    else:
                        update_query = """
                        UPDATE transcript_tasks
                        SET status = %s, progress = %s, error_message = %s
                        WHERE task_id = %s
                        """
                    cursor.execute(update_query, (status, progress, error_message, business_task_id))
                    connection.commit()
                finally:
                    cursor.close()
        except Exception as e:
            print(f"Error updating task status in database: {e}")

    def _get_task_from_db(self, business_task_id: str) -> Optional[Dict[str, Optional[str]]]:
        """Load a task record from the database.

        Returns:
            A dict shaped like the Redis hash (``created_at`` may be
            ``None``), or ``None`` when the task does not exist or the query
            fails (the error is logged).
        """
        try:
            with get_db_connection() as connection:
                cursor = connection.cursor(dictionary=True)
                try:
                    query = """
                    SELECT tt.task_id as business_task_id, tt.paraformer_task_id, tt.meeting_id, tt.status, tt.created_at
                    FROM transcript_tasks tt
                    WHERE tt.task_id = %s
                    """
                    cursor.execute(query, (business_task_id,))
                    result = cursor.fetchone()
                finally:
                    cursor.close()

            if result:
                # Convert to the same string format as the Redis hash.
                return {
                    'business_task_id': result['business_task_id'],
                    'paraformer_task_id': result['paraformer_task_id'],
                    'meeting_id': str(result['meeting_id']),
                    'status': result['status'],
                    'created_at': result['created_at'].isoformat() if result['created_at'] else None
                }
            return None
        except Exception as e:
            print(f"Error getting task from database: {e}")
            return None

    def _process_transcription_result(self, business_task_id: str, meeting_id: int, paraformer_output: Any):
        """Download the transcription JSON and store its segments.

        Raises:
            Exception: If the provider output has no results, the download
                fails or times out, or saving the segments fails. The error
                is logged and re-raised for the caller to handle.
        """
        try:
            if not paraformer_output.get('results'):
                raise Exception("No transcription results found in the provider response.")

            # Only the first result is used: one file is submitted per task.
            transcription_url = paraformer_output['results'][0]['transcription_url']
            print(f"Fetching transcription from URL: {transcription_url}")

            response = requests.get(transcription_url, timeout=self.RESULT_FETCH_TIMEOUT)
            response.raise_for_status()
            transcription_data = response.json()

            # Save the transcript content to the database.
            self._save_segments_to_db(transcription_data, meeting_id)
            print(f"Transcription result processed for task: {business_task_id}")
        except Exception as e:
            # Log the specific error and re-raise so the caller can record it.
            print(f"Error processing transcription result for task {business_task_id}: {e}")
            raise

    @staticmethod
    def _build_segment_rows(data: dict, meeting_id: int) -> list:
        """Flatten transcription JSON into ``transcript_segments`` row tuples.

        Each row is ``(meeting_id, speaker_id, speaker_tag, begin_time,
        end_time, text)``. A sentence without ``speaker_id`` gets ``-1``.
        """
        return [
            (
                meeting_id,
                sentence.get('speaker_id', -1),
                f"发言人 {sentence.get('speaker_id', -1)}",  # default speaker_tag
                sentence.get('begin_time'),
                sentence.get('end_time'),
                sentence.get('text'),
            )
            for transcript in data.get('transcripts', [])
            for sentence in transcript.get('sentences', [])
        ]

    def _save_segments_to_db(self, data: dict, meeting_id: int):
        """Replace the meeting's transcript segments with those found in ``data``.

        Does nothing (and leaves existing segments untouched) when ``data``
        contains no sentences. Database errors are logged and re-raised.
        """
        segments_to_insert = self._build_segment_rows(data, meeting_id)
        if not segments_to_insert:
            print("No segments to save.")
            return

        try:
            with get_db_connection() as connection:
                cursor = connection.cursor()
                try:
                    # Remove the meeting's existing segments first; the
                    # delete and the insert are committed together below.
                    delete_query = "DELETE FROM transcript_segments WHERE meeting_id = %s"
                    cursor.execute(delete_query, (meeting_id,))
                    print(f"Deleted existing segments for meeting_id: {meeting_id}")

                    # Insert the new segments.
                    insert_query = '''
                    INSERT INTO transcript_segments (meeting_id, speaker_id, speaker_tag, start_time_ms, end_time_ms, text_content)
                    VALUES (%s, %s, %s, %s, %s, %s)
                    '''
                    cursor.executemany(insert_query, segments_to_insert)
                    connection.commit()
                finally:
                    cursor.close()
            print(f"Successfully saved {len(segments_to_insert)} segments to the database for meeting_id: {meeting_id}")
        except Exception as e:
            print(f"Database error when saving segments: {e}")
            raise