nex_basse/backend/app/api/v1/endpoints/hotwords.py

73 lines
2.1 KiB
Python

from typing import List
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from pypinyin import pinyin, Style
from app.core.deps import get_db
from app.models.hotword import Hotword
from app.schemas.hotword import Hotword as HotwordSchema, HotwordCreate, HotwordUpdate
# Router holding the hotword CRUD endpoints defined in this module.
router = APIRouter()


@router.get("", response_model=List[HotwordSchema])
def list_hotwords(db: Session = Depends(get_db)):
    """Return every stored hotword.

    No filtering or pagination is applied: all rows of the ``Hotword`` table
    are loaded and serialized through ``HotwordSchema``.
    """
    every_hotword = db.query(Hotword)
    return every_hotword.all()
@router.post("", response_model=HotwordSchema)
def create_hotword(
    hotword_in: HotwordCreate,
    db: Session = Depends(get_db),
):
    """Create a new hotword whose pinyin is derived from its word.

    The pinyin is not taken from the request. It is computed with pypinyin
    (``Style.NORMAL``, i.e. without tone marks): the first reading of each
    segment is kept and the readings are joined with single spaces. The new
    row is always stored with ``scope="global"``.

    Returns the freshly persisted ``Hotword`` (refreshed after commit so
    DB-generated fields are populated).
    """
    # pypinyin returns one list per segment, e.g. [['zhong'], ['guo']];
    # take the first reading of each one.
    readings = pinyin(hotword_in.word, style=Style.NORMAL)
    record = Hotword(
        word=hotword_in.word,
        pinyin=" ".join(segment[0] for segment in readings),
        weight=hotword_in.weight,
        scope="global",
    )
    db.add(record)
    db.commit()
    db.refresh(record)
    return record
@router.put("/{hotword_id}", response_model=HotwordSchema)
def update_hotword(
    hotword_id: int,
    hotword_in: HotwordUpdate,
    db: Session = Depends(get_db),
):
    """Partially update the hotword with id ``hotword_id``.

    Only fields the client explicitly sent are applied
    (``exclude_unset=True``). Whenever ``word`` is supplied with a non-None
    value, ``pinyin`` is regenerated from it (pypinyin ``Style.NORMAL``, first
    reading per segment, space-joined). This keeps the two columns in sync.

    Raises:
        HTTPException: 404 if no hotword with that id exists.

    Returns the updated ``Hotword`` refreshed from the database.
    """
    hotword = db.query(Hotword).filter(Hotword.id == hotword_id).first()
    if not hotword:
        raise HTTPException(status_code=404, detail="Hotword not found")

    update_data = hotword_in.dict(exclude_unset=True)

    # Regenerate pinyin for ANY supplied word, including "" (which yields an
    # empty pinyin). A truthiness check here would leave a stale pinyin behind
    # when the word is cleared.
    new_word = update_data.get("word")
    if new_word is not None:
        py_list = pinyin(new_word, style=Style.NORMAL)
        update_data["pinyin"] = " ".join(item[0] for item in py_list)

    for field, value in update_data.items():
        setattr(hotword, field, value)
    db.add(hotword)
    db.commit()
    db.refresh(hotword)
    return hotword
@router.delete("/{hotword_id}")
def delete_hotword(
    hotword_id: int,
    db: Session = Depends(get_db),
):
    """Delete the hotword with id ``hotword_id``.

    Raises:
        HTTPException: 404 if no hotword with that id exists.

    Returns ``{"status": "success"}`` once the deletion is committed.
    """
    target = db.query(Hotword).filter_by(id=hotword_id).first()
    if not target:
        raise HTTPException(status_code=404, detail="Hotword not found")

    db.delete(target)
    db.commit()
    return {"status": "success"}