imetting_backend/app/api/endpoints/hot_words.py

174 lines
6.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

from fastapi import APIRouter, Depends, HTTPException
from app.core.database import get_db_connection
from app.core.auth import get_current_admin_user
from app.core.response import create_api_response
from app.core.config import QWEN_API_KEY
from app.services.system_config_service import SystemConfigService
from pydantic import BaseModel
from typing import Optional, List
import json
import dashscope
from dashscope.audio.asr import VocabularyService
from datetime import datetime
from http import HTTPStatus
# Router on which the hot-word admin endpoints in this module are registered.
router = APIRouter()
# Full representation of a hot-word record: identity, content and timestamps.
# Field names match the columns the endpoints in this module read and write
# on the ``hot_words`` table.
class HotWordItem(BaseModel):
    # Row identifier.
    id: int
    # The hot-word text itself.
    text: str
    # Weight value stored with the word.
    weight: int
    # Language code stored with the word.
    lang: str
    # Status flag; the sync endpoint selects rows where this equals 1.
    status: int
    # Creation and last-modification timestamps.
    create_time: datetime
    update_time: datetime
# Request body accepted when creating a hot word.
# Only ``text`` is required; the remaining fields fall back to the defaults
# declared here.
class CreateHotWordRequest(BaseModel):
    # The hot-word text (required).
    text: str
    # Weight stored with the word; defaults to 4.
    weight: int = 4
    # Language code; defaults to "zh".
    lang: str = "zh"
    # Status flag; defaults to 1, the value the sync endpoint selects on.
    status: int = 1
# Request body accepted when updating a hot word.
# Every field is optional; a field left as ``None`` is skipped by the update
# endpoint, so only the supplied columns are rewritten.
class UpdateHotWordRequest(BaseModel):
    # New hot-word text, if it should change.
    text: Optional[str] = None
    # New weight, if it should change.
    weight: Optional[int] = None
    # New language code, if it should change.
    lang: Optional[str] = None
    # New status flag, if it should change.
    status: Optional[int] = None
@router.get("/admin/hot-words", response_model=dict)
async def list_hot_words(current_user: dict = Depends(get_current_admin_user)):
    """Return every hot word, most recently updated first.

    Args:
        current_user: Value produced by the ``get_current_admin_user``
            dependency; the handler does not read it.

    Returns:
        The ``create_api_response`` result: code "200" with the fetched rows
        as ``data``, or code "500" carrying the error text if anything raises.
    """
    try:
        with get_db_connection() as conn:
            cursor = conn.cursor(dictionary=True)
            try:
                cursor.execute("SELECT * FROM hot_words ORDER BY update_time DESC")
                items = cursor.fetchall()
            finally:
                # Close the cursor even when the query raises, so it is never
                # left open on the connection.
                cursor.close()
        return create_api_response(code="200", message="获取成功", data=items)
    except Exception as e:
        return create_api_response(code="500", message=f"获取失败: {str(e)}")
@router.post("/admin/hot-words", response_model=dict)
async def create_hot_word(request: CreateHotWordRequest, current_user: dict = Depends(get_current_admin_user)):
    """Insert a new hot word and report its generated id.

    Args:
        request: Text, weight, language and status of the word to insert.
        current_user: Value produced by the ``get_current_admin_user``
            dependency; the handler does not read it.

    Returns:
        The ``create_api_response`` result: code "200" with
        ``{"id": <new row id>}`` as ``data``, or code "500" carrying the
        error text if anything raises.
    """
    query = "INSERT INTO hot_words (text, weight, lang, status) VALUES (%s, %s, %s, %s)"
    values = (request.text, request.weight, request.lang, request.status)
    try:
        with get_db_connection() as conn:
            cursor = conn.cursor()
            try:
                cursor.execute(query, values)
                new_id = cursor.lastrowid
                conn.commit()
            finally:
                # Close the cursor even when the insert or commit raises.
                cursor.close()
        return create_api_response(code="200", message="创建成功", data={"id": new_id})
    except Exception as e:
        return create_api_response(code="500", message=f"创建失败: {str(e)}")
@router.put("/admin/hot-words/{id}", response_model=dict)
async def update_hot_word(id: int, request: UpdateHotWordRequest, current_user: dict = Depends(get_current_admin_user)):
    """Update the supplied fields of one hot word.

    Only fields of ``request`` that are not ``None`` are written. When no
    field is supplied the handler answers with code "400" without touching
    the database.

    Args:
        id: Primary key of the row to update (name fixed by the route path).
        request: Optional new values for text, weight, lang and status.
        current_user: Value produced by the ``get_current_admin_user``
            dependency; the handler does not read it.

    Returns:
        The ``create_api_response`` result: code "200" on success, "400"
        when there is nothing to update, or "500" carrying the error text
        if anything raises.
    """
    try:
        # Column names come from this fixed tuple, never from user input, so
        # interpolating them into the SQL text below is safe; the values are
        # always passed as bound parameters.
        candidates = (
            ("text", request.text),
            ("weight", request.weight),
            ("lang", request.lang),
            ("status", request.status),
        )
        supplied = [(column, value) for column, value in candidates if value is not None]

        # Decide before opening a connection: previously the early return
        # happened after the cursor was created and left it unclosed.
        if not supplied:
            return create_api_response(code="400", message="无更新内容")

        set_clause = ", ".join(f"{column} = %s" for column, _ in supplied)
        query = f"UPDATE hot_words SET {set_clause} WHERE id = %s"
        params = tuple(value for _, value in supplied) + (id,)

        with get_db_connection() as conn:
            cursor = conn.cursor()
            try:
                cursor.execute(query, params)
                conn.commit()
            finally:
                # Close the cursor even when the update or commit raises.
                cursor.close()
        return create_api_response(code="200", message="更新成功")
    except Exception as e:
        return create_api_response(code="500", message=f"更新失败: {str(e)}")
@router.delete("/admin/hot-words/{id}", response_model=dict)
async def delete_hot_word(id: int, current_user: dict = Depends(get_current_admin_user)):
    """Delete the hot word with the given id.

    The handler does not check whether a row was actually removed; a
    request for an id with no matching row still yields code "200".

    Args:
        id: Primary key of the row to delete (name fixed by the route path).
        current_user: Value produced by the ``get_current_admin_user``
            dependency; the handler does not read it.

    Returns:
        The ``create_api_response`` result: code "200" on success, or code
        "500" carrying the error text if anything raises.
    """
    try:
        with get_db_connection() as conn:
            cursor = conn.cursor()
            try:
                cursor.execute("DELETE FROM hot_words WHERE id = %s", (id,))
                conn.commit()
            finally:
                # Close the cursor even when the delete or commit raises.
                cursor.close()
        return create_api_response(code="200", message="删除成功")
    except Exception as e:
        return create_api_response(code="500", message=f"删除失败: {str(e)}")
@router.post("/admin/hot-words/sync", response_model=dict)
async def sync_hot_words(current_user: dict = Depends(get_current_admin_user)):
    """Synchronize all enabled hot words to an Alibaba Cloud DashScope vocabulary.

    Steps:
      1. Read every row of ``hot_words`` with ``status = 1``.
      2. If a vocabulary id is already stored in the system config, try to
         update that vocabulary; if the update raises, fall back to creating
         a new one. With no stored id, create a new one directly.
      3. Store the resulting vocabulary id back into the system config.

    Args:
        current_user: Value produced by the ``get_current_admin_user``
            dependency; the handler does not read it.

    Returns:
        The ``create_api_response`` result: code "200" with
        ``{"vocabulary_id": ...}`` on success, "400" when no hot word is
        enabled, or "500" carrying the error text on any failure.
    """
    # NOTE(review): the DB and DashScope calls below appear to be synchronous;
    # inside an ``async def`` handler they would block the event loop - confirm.
    try:
        dashscope.api_key = QWEN_API_KEY

        # 1. Load every enabled hot word.
        with get_db_connection() as conn:
            cursor = conn.cursor(dictionary=True)
            try:
                cursor.execute("SELECT text, weight, lang FROM hot_words WHERE status = 1")
                hot_words = cursor.fetchall()
            finally:
                # Close the cursor even when the query raises.
                cursor.close()

        # Build the vocabulary payload from the selected rows.
        vocabulary_list = [
            {"text": hw["text"], "weight": hw["weight"], "lang": hw["lang"]}
            for hw in hot_words
        ]
        # Bail out before any config lookup or remote call when there is
        # nothing to synchronize.
        if not vocabulary_list:
            return create_api_response(code="400", message="没有启用的热词可同步")

        # 2. Look up the vocabulary id stored by a previous sync, if any.
        existing_vocab_id = SystemConfigService.get_asr_vocabulary_id()

        # 3. Call the Alibaba Cloud API.
        service = VocabularyService()
        vocab_id = existing_vocab_id
        if existing_vocab_id:
            # Try to update the existing vocabulary; on success the id is kept.
            try:
                service.update_vocabulary(
                    vocabulary_id=existing_vocab_id,
                    vocabulary=vocabulary_list,
                )
            except Exception as update_error:
                # The update failed (e.g. the resource no longer exists):
                # clear the id so a new vocabulary is created below.
                print(f"Update vocabulary failed: {update_error}, trying to create new one.")
                vocab_id = None

        if not vocab_id:
            # Create a new vocabulary.
            try:
                vocab_id = service.create_vocabulary(
                    prefix='imeeting',
                    target_model='paraformer-v2',
                    vocabulary=vocabulary_list,
                )
            except Exception as api_error:
                return create_api_response(code="500", message=f"同步到阿里云失败: {str(api_error)}")

        # 4. Store the vocabulary id in the system config.
        if vocab_id:
            SystemConfigService.set_config(
                SystemConfigService.ASR_VOCABULARY_ID,
                vocab_id,
            )
        return create_api_response(code="200", message="同步成功", data={"vocabulary_id": vocab_id})
    except Exception as e:
        return create_api_response(code="500", message=f"同步异常: {str(e)}")