imetting_backend/app/services/async_llm_service.py

216 lines
9.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""
异步LLM服务 - 处理会议总结生成的异步任务
采用FastAPI BackgroundTasks模式
"""
import uuid
import time
from datetime import datetime
from typing import Optional, Dict, Any, List
import redis
from app.core.config import REDIS_CONFIG
from app.core.database import get_db_connection
from app.services.llm_service import LLMService
class AsyncLLMService:
    """Asynchronous LLM meeting-summary service (FastAPI BackgroundTasks style).

    ``start_summary_generation`` records a task in the ``llm_tasks`` table and
    mirrors it into a Redis hash (``llm_task:<task_id>``) for fast status
    polling. The actual work, ``_process_task``, is meant to be triggered by
    the caller (e.g. the API layer's BackgroundTasks).
    """

    # Lifetime of a task hash in Redis, in seconds (24 hours).
    REDIS_TASK_TTL = 86400

    def __init__(self):
        # Make the redis client decode responses to str so hash fields can be
        # read with plain string keys throughout this class.
        # NOTE(review): this mutates the shared REDIS_CONFIG dict in place, so
        # the default also applies to any other code that later reads it.
        if 'decode_responses' not in REDIS_CONFIG:
            REDIS_CONFIG['decode_responses'] = True
        self.redis_client = redis.Redis(**REDIS_CONFIG)
        self.llm_service = LLMService()  # reuse the existing synchronous LLM service

    @staticmethod
    def _redis_key(task_id: str) -> str:
        """Return the Redis hash key under which a task's state is stored."""
        return f"llm_task:{task_id}"

    def start_summary_generation(self, meeting_id: int, user_prompt: str = "") -> str:
        """Create an asynchronous summary task without running it.

        Execution is triggered externally (e.g. by BackgroundTasks in the API
        layer) by calling ``_process_task`` with the returned id.

        Args:
            meeting_id: Meeting ID.
            user_prompt: Extra user-supplied prompt text.

        Returns:
            str: The new task ID.

        Raises:
            Exception: Any DB or Redis error is logged and re-raised.
        """
        try:
            task_id = str(uuid.uuid4())
            # Persist the task record in the database first.
            self._save_task_to_db(task_id, meeting_id, user_prompt)

            # Mirror the task into Redis for fast status lookups.
            current_time = datetime.now().isoformat()
            task_data = {
                'task_id': task_id,
                'meeting_id': str(meeting_id),
                'user_prompt': user_prompt,
                'status': 'pending',
                'progress': '0',
                'created_at': current_time,
                'updated_at': current_time
            }
            key = self._redis_key(task_id)
            self.redis_client.hset(key, mapping=task_data)
            self.redis_client.expire(key, self.REDIS_TASK_TTL)

            print(f"LLM summary task created: {task_id} for meeting: {meeting_id}")
            return task_id
        except Exception as e:
            print(f"Error starting summary generation: {e}")
            raise  # bare raise keeps the original traceback intact

    def _process_task(self, task_id: str):
        """Run one summary task end to end; intended to be called by BackgroundTasks.

        Progress is reported to Redis at each step. On success the summary is
        saved and the task is marked ``completed`` in both the DB and Redis.
        On any error the task is marked ``failed`` with the error message.
        Exceptions are never propagated to the caller.
        """
        print(f"Background task started for LLM task: {task_id}")
        try:
            # Load the task parameters from Redis.
            task_data = self.redis_client.hgetall(self._redis_key(task_id))
            if not task_data:
                print(f"Error: Task {task_id} not found in Redis for processing.")
                # Mark the DB record terminal so it does not stay 'pending' forever.
                self._update_task_in_db(task_id, 'failed', 0, error_message="Task not found in Redis")
                return

            meeting_id = int(task_data['meeting_id'])
            user_prompt = task_data.get('user_prompt', '')

            # 1. Mark as processing.
            self._update_task_status_in_redis(task_id, 'processing', 10, message="任务已开始...")

            # 2. Fetch the meeting transcript.
            self._update_task_status_in_redis(task_id, 'processing', 30, message="获取会议转录内容...")
            transcript_text = self.llm_service._get_meeting_transcript(meeting_id)
            if not transcript_text:
                raise Exception("无法获取会议转录内容")

            # 3. Build the prompt.
            self._update_task_status_in_redis(task_id, 'processing', 40, message="准备AI提示词...")
            full_prompt = self.llm_service._build_prompt(transcript_text, user_prompt)

            # 4. Call the LLM API.
            self._update_task_status_in_redis(task_id, 'processing', 50, message="AI正在分析会议内容...")
            summary_content = self.llm_service._call_llm_api(full_prompt)
            if not summary_content:
                raise Exception("LLM API调用失败或返回空内容")

            # 5. Save the summary to the main table.
            self._update_task_status_in_redis(task_id, 'processing', 95, message="保存总结结果...")
            self.llm_service._save_summary_to_db(meeting_id, summary_content, user_prompt)

            # 6. Done.
            self._update_task_in_db(task_id, 'completed', 100, result=summary_content)
            self._update_task_status_in_redis(task_id, 'completed', 100, result=summary_content)

            print(f"Task {task_id} completed successfully")

        except Exception as e:
            error_msg = str(e)
            print(f"Task {task_id} failed: {error_msg}")
            # Record the failure in both stores.
            self._update_task_in_db(task_id, 'failed', 0, error_message=error_msg)
            self._update_task_status_in_redis(task_id, 'failed', 0, error_message=error_msg)

    # --- Status queries and database helpers ---

    def get_task_status(self, task_id: str) -> Dict[str, Any]:
        """Return the current status of a task.

        Redis is consulted first. If the hash is missing (e.g. expired), the
        ``llm_tasks`` row is used instead. This never raises: unknown ids
        give ``status='not_found'`` and lookup errors give ``status='error'``.
        """
        try:
            task_data = self.redis_client.hgetall(self._redis_key(task_id))
            if not task_data:
                task_data = self._get_task_from_db(task_id)

            if not task_data:
                return {'task_id': task_id, 'status': 'not_found', 'error_message': 'Task not found'}

            return {
                'task_id': task_id,
                'status': task_data.get('status') or 'unknown',
                # NULL DB columns come back as None, so fall back to 0 before int().
                'progress': int(task_data.get('progress') or 0),
                'meeting_id': int(task_data.get('meeting_id') or 0),
                'created_at': task_data.get('created_at'),
                'updated_at': task_data.get('updated_at'),
                # Latest human-readable progress message (only stored in Redis).
                'message': task_data.get('message'),
                'result': task_data.get('result'),
                'error_message': task_data.get('error_message')
            }
        except Exception as e:
            print(f"Error getting task status: {e}")
            return {'task_id': task_id, 'status': 'error', 'error_message': str(e)}

    def get_meeting_llm_tasks(self, meeting_id: int) -> List[Dict[str, Any]]:
        """Return every LLM task for a meeting, newest first; [] on DB error.

        ``created_at`` / ``completed_at`` are converted to ISO-8601 strings.
        """
        try:
            with get_db_connection() as connection:
                cursor = connection.cursor(dictionary=True)
                query = "SELECT task_id, status, progress, user_prompt, created_at, completed_at, error_message FROM llm_tasks WHERE meeting_id = %s ORDER BY created_at DESC"
                cursor.execute(query, (meeting_id,))
                tasks = cursor.fetchall()
                for task in tasks:
                    for field in ('created_at', 'completed_at'):
                        if task.get(field):
                            task[field] = task[field].isoformat()
                return tasks
        except Exception as e:
            print(f"Error getting meeting LLM tasks: {e}")
            return []

    def _update_task_status_in_redis(self, task_id: str, status: str, progress: int,
                                     message: Optional[str] = None, result: Optional[str] = None,
                                     error_message: Optional[str] = None):
        """Update the task hash in Redis (best effort; errors are only logged).

        ``message``, ``result`` and ``error_message`` are written only when
        truthy, so earlier values are not overwritten with empty ones.
        """
        try:
            update_data = {
                'status': status,
                'progress': str(progress),
                'updated_at': datetime.now().isoformat()
            }
            if message:
                update_data['message'] = message
            if result:
                update_data['result'] = result
            if error_message:
                update_data['error_message'] = error_message

            self.redis_client.hset(self._redis_key(task_id), mapping=update_data)
        except Exception as e:
            print(f"Error updating task status in Redis: {e}")

    def _save_task_to_db(self, task_id: str, meeting_id: int, user_prompt: str):
        """Insert a new ``pending`` task row; errors are logged and re-raised."""
        try:
            with get_db_connection() as connection:
                cursor = connection.cursor()
                insert_query = "INSERT INTO llm_tasks (task_id, meeting_id, user_prompt, status, progress, created_at) VALUES (%s, %s, %s, 'pending', 0, NOW())"
                cursor.execute(insert_query, (task_id, meeting_id, user_prompt))
                connection.commit()
        except Exception as e:
            print(f"Error saving task to database: {e}")
            raise

    @staticmethod
    def _build_task_update(task_id: str, status: str, progress: int,
                           result: Optional[str] = None,
                           error_message: Optional[str] = None) -> tuple:
        """Build the ``(query, params)`` pair for updating a task row.

        For ``completed`` tasks the result is stored and ``completed_at`` is
        set. The params tuple always follows the placeholder order in the SQL.
        """
        if status == 'completed':
            query = "UPDATE llm_tasks SET status = %s, progress = %s, error_message = %s, result = %s, completed_at = NOW() WHERE task_id = %s"
            params = (status, progress, error_message, result, task_id)
        else:
            query = "UPDATE llm_tasks SET status = %s, progress = %s, error_message = %s WHERE task_id = %s"
            params = (status, progress, error_message, task_id)
        return query, params

    def _update_task_in_db(self, task_id: str, status: str, progress: int,
                           result: Optional[str] = None, error_message: Optional[str] = None):
        """Update the task row in the database (best effort; errors are only logged)."""
        try:
            query, params = self._build_task_update(task_id, status, progress,
                                                    result=result, error_message=error_message)
            with get_db_connection() as connection:
                cursor = connection.cursor()
                cursor.execute(query, params)
                connection.commit()
        except Exception as e:
            print(f"Error updating task in database: {e}")

    @staticmethod
    def _row_to_task_dict(row: Dict[str, Any]) -> Dict[str, Optional[str]]:
        """Convert a DB row to Redis-like string values.

        datetimes become ISO-8601 strings, other values go through ``str()``,
        and NULL stays ``None`` (rather than the string ``'None'``).
        """
        normalized: Dict[str, Optional[str]] = {}
        for key, value in row.items():
            if value is None:
                normalized[key] = None
            elif isinstance(value, datetime):
                normalized[key] = value.isoformat()
            else:
                normalized[key] = str(value)
        return normalized

    def _get_task_from_db(self, task_id: str) -> Optional[Dict[str, Optional[str]]]:
        """Fetch a task row by id; None if it is absent or the query fails."""
        try:
            with get_db_connection() as connection:
                cursor = connection.cursor(dictionary=True)
                query = "SELECT * FROM llm_tasks WHERE task_id = %s"
                cursor.execute(query, (task_id,))
                task = cursor.fetchone()
                if task:
                    return self._row_to_task_dict(task)
                return None
        except Exception as e:
            print(f"Error getting task from database: {e}")
            return None
# Module-level shared instance. Creating it runs AsyncLLMService.__init__
# (Redis client + LLMService construction) at import time of this module.
async_llm_service = AsyncLLMService()