174 lines
6.9 KiB
Python
174 lines
6.9 KiB
Python
from fastapi import APIRouter, Depends, HTTPException
|
||
from app.core.database import get_db_connection
|
||
from app.core.auth import get_current_admin_user
|
||
from app.core.response import create_api_response
|
||
from app.core.config import QWEN_API_KEY
|
||
from app.services.system_config_service import SystemConfigService
|
||
from pydantic import BaseModel
|
||
from typing import Optional, List
|
||
import json
|
||
import dashscope
|
||
from dashscope.audio.asr import VocabularyService
|
||
from datetime import datetime
|
||
from http import HTTPStatus
|
||
|
||
# Router carrying this module's /admin/hot-words endpoints; every endpoint below
# depends on get_current_admin_user. Where it is mounted is not visible here.
router = APIRouter()
|
||
|
||
# Model of a single hot word record; its fields appear to mirror the columns of
# the hot_words table (TODO confirm against the schema).
# NOTE(review): not referenced by the endpoints visible in this module
# (list_hot_words returns raw dict rows with response_model=dict) — check for
# other users before changing it.
class HotWordItem(BaseModel):
    id: int  # row id; the same value used in the /admin/hot-words/{id} routes
    text: str  # hot word text; uploaded as the vocabulary "text" entry by the sync endpoint
    weight: int  # uploaded as the vocabulary "weight" entry by the sync endpoint
    lang: str  # uploaded as the vocabulary "lang" entry; creation defaults to "zh"
    status: int  # 1 = enabled (only status = 1 rows are synced); other values presumably disabled
    create_time: datetime
    update_time: datetime  # list_hot_words sorts by this, newest first
|
||
|
||
# Request body for POST /admin/hot-words (create_hot_word). All four fields are
# inserted into hot_words as given.
class CreateHotWordRequest(BaseModel):
    text: str  # the hot word itself
    weight: int = 4  # default weight; no range validation is applied here
    lang: str = "zh"  # default language code
    status: int = 1  # 1 = enabled, i.e. included by the sync endpoint (status = 1 filter)
|
||
|
||
# Request body for PUT /admin/hot-words/{id} (update_hot_word).
# Every field is optional: update_hot_word only writes fields that are not None,
# so a partial update is possible, but no column can be set to NULL through it.
class UpdateHotWordRequest(BaseModel):
    text: Optional[str] = None
    weight: Optional[int] = None
    lang: Optional[str] = None
    status: Optional[int] = None
|
||
|
||
@router.get("/admin/hot-words", response_model=dict)
async def list_hot_words(current_user: dict = Depends(get_current_admin_user)):
    """List all hot words, most recently updated first.

    Returns code "200" with every hot_words row (one dict per row, all columns)
    as ``data``, or code "500" with the error text if the query fails.
    """
    try:
        with get_db_connection() as conn:
            cursor = conn.cursor(dictionary=True)
            try:
                cursor.execute("SELECT * FROM hot_words ORDER BY update_time DESC")
                items = cursor.fetchall()
            finally:
                # Close the cursor on the error path too; previously it was only
                # closed when the query succeeded.
                cursor.close()
            return create_api_response(code="200", message="获取成功", data=items)
    except Exception as e:
        return create_api_response(code="500", message=f"获取失败: {e}")
|
||
|
||
@router.post("/admin/hot-words", response_model=dict)
async def create_hot_word(request: CreateHotWordRequest, current_user: dict = Depends(get_current_admin_user)):
    """Create a hot word and return its new id.

    Returns code "200" with ``data={"id": <new row id>}`` on success, code "400"
    when ``request.text`` is empty or whitespace-only, or code "500" with the
    error text if the insert fails.
    """
    # A blank hot word carries no information, yet once enabled it would be
    # uploaded by the sync endpoint; reject it before touching the database.
    if not request.text.strip():
        return create_api_response(code="400", message="热词内容不能为空")

    try:
        with get_db_connection() as conn:
            cursor = conn.cursor()
            try:
                cursor.execute(
                    "INSERT INTO hot_words (text, weight, lang, status) VALUES (%s, %s, %s, %s)",
                    (request.text, request.weight, request.lang, request.status),
                )
                new_id = cursor.lastrowid
                conn.commit()
            finally:
                # Release the cursor even if the insert or commit raises.
                cursor.close()
            return create_api_response(code="200", message="创建成功", data={"id": new_id})
    except Exception as e:
        return create_api_response(code="500", message=f"创建失败: {e}")
|
||
|
||
@router.put("/admin/hot-words/{id}", response_model=dict)
async def update_hot_word(id: int, request: UpdateHotWordRequest, current_user: dict = Depends(get_current_admin_user)):
    """Partially update the hot word with the given id.

    Only the fields of ``request`` that are not None are written. (The path
    parameter is named ``id`` because it is bound from the route path.)

    Returns code "200" on success; code "400" if no field was supplied or
    ``text`` is blank; code "404" if no hot word has this id; code "500" with
    the error text on database failure.
    """
    # Column names come from this fixed list, never from the client, so they are
    # safe to interpolate into the SQL; the values themselves stay parameterized.
    candidates = (
        ("text", request.text),
        ("weight", request.weight),
        ("lang", request.lang),
        ("status", request.status),
    )
    changes = [(column, value) for column, value in candidates if value is not None]

    if not changes:
        return create_api_response(code="400", message="无更新内容")
    if request.text is not None and not request.text.strip():
        return create_api_response(code="400", message="热词内容不能为空")

    set_clause = ", ".join(f"{column} = %s" for column, _ in changes)
    params = tuple(value for _, value in changes) + (id,)

    try:
        with get_db_connection() as conn:
            cursor = conn.cursor()
            try:
                # Check existence explicitly instead of relying on rowcount: MySQL
                # reports 0 affected rows for an existing row whose values did not
                # change (unless CLIENT_FOUND_ROWS is set), which would be
                # indistinguishable from a missing id.
                cursor.execute("SELECT 1 FROM hot_words WHERE id = %s", (id,))
                if cursor.fetchone() is None:
                    return create_api_response(code="404", message="热词不存在")

                cursor.execute(f"UPDATE hot_words SET {set_clause} WHERE id = %s", params)
                conn.commit()
            finally:
                # Release the cursor on every path (success, 404, or exception).
                cursor.close()
            return create_api_response(code="200", message="更新成功")
    except Exception as e:
        return create_api_response(code="500", message=f"更新失败: {e}")
|
||
|
||
@router.delete("/admin/hot-words/{id}", response_model=dict)
async def delete_hot_word(id: int, current_user: dict = Depends(get_current_admin_user)):
    """Delete the hot word with the given id.

    Returns code "200" when a row was deleted, code "404" when no hot word has
    this id, or code "500" with the error text on database failure.
    """
    try:
        with get_db_connection() as conn:
            cursor = conn.cursor()
            try:
                cursor.execute("DELETE FROM hot_words WHERE id = %s", (id,))
                # For DELETE the affected-row count is exactly the number of rows
                # removed, so 0 means the id did not exist.
                deleted = cursor.rowcount
                conn.commit()
            finally:
                # Release the cursor even if the delete or commit raises.
                cursor.close()
        if deleted == 0:
            return create_api_response(code="404", message="热词不存在")
        return create_api_response(code="200", message="删除成功")
    except Exception as e:
        return create_api_response(code="500", message=f"删除失败: {e}")
|
||
|
||
@router.post("/admin/hot-words/sync", response_model=dict)
async def sync_hot_words(current_user: dict = Depends(get_current_admin_user)):
    """Sync all enabled hot words to Alibaba Cloud DashScope.

    Uploads every hot_words row with status = 1 as an ASR vocabulary. It updates
    the vocabulary whose id is stored in system config, or creates a new one,
    then stores the resulting id back in system config.

    Returns code "200" with ``data={"vocabulary_id": ...}`` on success, code
    "400" if there are no enabled hot words, or code "500" if the DashScope call
    or any other step fails.
    """
    # Imported here so the module's import block stays untouched; fastapi is
    # already a dependency of this file.
    from fastapi.concurrency import run_in_threadpool

    try:
        # Every step (database queries, remote DashScope API calls) is blocking.
        # Running it in a worker thread keeps the event loop free for other
        # requests while the sync waits on the network.
        return await run_in_threadpool(_sync_hot_words_blocking)
    except Exception as e:
        return create_api_response(code="500", message=f"同步异常: {e}")


def _sync_hot_words_blocking():
    """Perform the hot-word sync synchronously and return the API response."""
    dashscope.api_key = QWEN_API_KEY

    # 1. Collect the enabled hot words as DashScope vocabulary entries.
    vocabulary_list = _load_enabled_vocabulary()
    if not vocabulary_list:
        return create_api_response(code="400", message="没有启用的热词可同步")

    # 2. Push them to DashScope, reusing the stored vocabulary id when possible.
    existing_vocab_id = SystemConfigService.get_asr_vocabulary_id()
    service = VocabularyService()
    try:
        vocab_id = _push_vocabulary(service, vocabulary_list, existing_vocab_id)
    except Exception as api_error:
        return create_api_response(code="500", message=f"同步到阿里云失败: {api_error}")

    # 3. Persist whichever vocabulary id is now in use.
    if vocab_id:
        SystemConfigService.set_config(SystemConfigService.ASR_VOCABULARY_ID, vocab_id)

    return create_api_response(code="200", message="同步成功", data={"vocabulary_id": vocab_id})


def _load_enabled_vocabulary():
    """Return enabled hot words (status = 1) as a list of {"text", "weight", "lang"} dicts."""
    with get_db_connection() as conn:
        cursor = conn.cursor(dictionary=True)
        try:
            cursor.execute("SELECT text, weight, lang FROM hot_words WHERE status = 1")
            rows = cursor.fetchall()
        finally:
            cursor.close()
    return [{"text": row['text'], "weight": row['weight'], "lang": row['lang']} for row in rows]


def _push_vocabulary(service, vocabulary_list, existing_vocab_id):
    """Upload ``vocabulary_list`` via ``service`` and return the vocabulary id now in use.

    If ``existing_vocab_id`` is set, that vocabulary is updated in place and its
    id is returned. If there is no id, or the update fails for any reason, a new
    vocabulary is created and its id is returned. Errors from create_vocabulary
    propagate to the caller.
    """
    import logging

    if existing_vocab_id:
        try:
            service.update_vocabulary(
                vocabulary_id=existing_vocab_id,
                vocabulary=vocabulary_list,
            )
        except Exception as update_error:
            # NOTE: the previous vocabulary is not deleted on the DashScope side
            # when falling back to creation. Its id is simply replaced in system
            # config, so a transient update failure leaves an orphaned vocabulary.
            logging.getLogger(__name__).warning(
                "Update vocabulary failed: %s, trying to create new one.", update_error
            )
        else:
            return existing_vocab_id

    return service.create_vocabulary(
        prefix='imeeting',
        target_model='paraformer-v2',
        vocabulary=vocabulary_list,
    )
|