# nex_basse/backend/app/api/v1/endpoints/prompts.py
#
# 199 lines
# 6.7 KiB
# Python

from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.orm import Session
from sqlalchemy import or_, and_
from app.core.db import get_db
from app.core.deps import get_current_user
from app.models import User, PromptTemplate, UserPromptConfig
from app.schemas.prompt import (
PromptTemplateOut,
PromptTemplateCreate,
PromptTemplateUpdate,
UserPromptConfigUpdate
)
from typing import List, Optional
# Router for the prompt-template endpoints; every route below is registered
# under the "/prompts" prefix and tagged "prompts".
router = APIRouter(prefix="/prompts", tags=["prompts"])
def is_admin(user: User) -> bool:
    """Return True when *user* holds the "admin" or "superuser" role code.

    Role codes are read from ``user.roles[*].role.role_code``. A user without
    a ``roles`` attribute, with ``roles`` set to ``None``/empty, or with role
    links whose ``role`` is missing is treated as having no such role instead
    of raising ``AttributeError``.
    """
    role_links = getattr(user, "roles", None) or ()
    role_codes = {
        link.role.role_code
        for link in role_links
        # Skip links whose role is not set.
        if getattr(link, "role", None) is not None
    }
    return not role_codes.isdisjoint(("admin", "superuser"))
@router.get("", response_model=List[PromptTemplateOut])
def list_prompts(
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
    keyword: Optional[str] = Query(None),
    category: Optional[str] = Query(None),
    scope: str = Query("personal", description="system or personal")
):
    """Return the prompt templates visible to the caller for the given scope.

    * ``scope == "system"``: system templates only (``is_system == 1``);
      the caller must pass ``is_admin`` or a 403 is raised.
    * any other value: system templates with ``status == 1`` plus the
      templates whose ``user_id`` is the caller's.

    ``keyword`` is matched as a substring against name, description and
    content; ``category`` is an exact match. Each returned item carries the
    caller's ``is_active`` / ``user_sort_order`` from ``UserPromptConfig``,
    falling back to ``1`` and the template's ``sort_order`` when the joined
    values are ``None``.

    Raises:
        HTTPException: 403 when a non-admin requests the system scope.
    """
    filters = []
    if scope == "system":
        # System-management entry point: system-level templates only.
        if not is_admin(current_user):
            raise HTTPException(status_code=403, detail="无权访问系统模板库")
        filters.append(PromptTemplate.is_system == 1)
    else:
        # Personal entry point: published system templates + own templates.
        filters.append(or_(
            and_(PromptTemplate.is_system == 1, PromptTemplate.status == 1),
            PromptTemplate.user_id == current_user.user_id,
        ))

    if keyword:
        # autoescape=True makes "%" and "_" in the user-supplied keyword match
        # literally instead of acting as LIKE wildcards.
        filters.append(or_(
            PromptTemplate.name.contains(keyword, autoescape=True),
            PromptTemplate.description.contains(keyword, autoescape=True),
            PromptTemplate.content.contains(keyword, autoescape=True),
        ))
    if category:
        filters.append(PromptTemplate.category == category)

    # Outer join so templates without a config row for this user still appear.
    query = db.query(
        PromptTemplate,
        UserPromptConfig.is_active,
        UserPromptConfig.user_sort_order,
    ).outerjoin(
        UserPromptConfig,
        (UserPromptConfig.template_id == PromptTemplate.id)
        & (UserPromptConfig.user_id == current_user.user_id),
    ).filter(and_(*filters))

    out = []
    for template, is_active, user_sort_order in query.all():
        item = PromptTemplateOut.model_validate(template)
        item.is_active = is_active if is_active is not None else 1
        item.user_sort_order = (
            user_sort_order if user_sort_order is not None else template.sort_order
        )
        out.append(item)

    # Ordering: global order for the system scope, per-user order otherwise.
    if scope == "system":
        out.sort(key=lambda x: x.sort_order)
    else:
        out.sort(key=lambda x: (x.user_sort_order, x.sort_order))
    return out
@router.post("", response_model=PromptTemplateOut)
def create_prompt(
    payload: PromptTemplateCreate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Create a prompt template and return it.

    A payload with a truthy ``is_system`` creates a system template with
    ``user_id`` set to ``None``; this requires ``is_admin`` to pass. Any other
    payload creates a template owned by the caller with ``is_system`` forced
    to ``0``. The response has ``is_active`` set to ``True`` and
    ``user_sort_order`` set to the new template's ``sort_order``.

    Raises:
        HTTPException: 403 when a non-admin asks for a system template.
    """
    caller_is_admin = is_admin(current_user)
    fields = payload.model_dump()
    wants_system = bool(fields.get("is_system"))

    # Explicitly requesting a system template requires admin rights.
    if wants_system and not caller_is_admin:
        raise HTTPException(status_code=403, detail="仅管理员可创建系统模板")

    if wants_system:
        fields["user_id"] = None
    else:
        fields.update(user_id=current_user.user_id, is_system=0)

    template = PromptTemplate(**fields)
    db.add(template)
    db.commit()
    db.refresh(template)

    response = PromptTemplateOut.model_validate(template)
    response.is_active = True
    response.user_sort_order = template.sort_order
    return response
@router.put("/{prompt_id}", response_model=PromptTemplateOut)
def update_prompt(
    prompt_id: int,
    payload: PromptTemplateUpdate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Update a prompt template and return it with the caller's config.

    System templates may only be modified by callers passing ``is_admin``;
    non-system templates only by the user whose ``user_id`` matches. Only
    fields explicitly set in the payload are applied. ``is_system`` is
    honoured for admins only (the template's ``user_id`` becomes ``None``
    when it is truthy, otherwise the caller's id) and is silently dropped
    for everyone else.

    The response carries the caller's ``is_active`` / ``user_sort_order``
    from ``UserPromptConfig``, with the same fallbacks as the list endpoint
    (``1`` and the template's ``sort_order``).

    Raises:
        HTTPException: 404 when the template does not exist; 403 when the
            caller may not modify it.
    """
    item = db.query(PromptTemplate).filter(PromptTemplate.id == prompt_id).first()
    if not item:
        raise HTTPException(status_code=404, detail="模板不存在")

    user_is_admin = is_admin(current_user)
    # System templates: admins only.
    if item.is_system and not user_is_admin:
        raise HTTPException(status_code=403, detail="无权修改系统模板")
    # Personal templates: owner only.
    if not item.is_system and item.user_id != current_user.user_id:
        raise HTTPException(status_code=403, detail="无权修改他人模板")

    update_data = payload.model_dump(exclude_unset=True)
    if "is_system" in update_data:
        # Always remove the flag from the generic update; apply it only for admins.
        new_val = update_data.pop("is_system")
        if user_is_admin:
            item.is_system = new_val
            item.user_id = None if new_val else current_user.user_id
    for k, v in update_data.items():
        setattr(item, k, v)
    db.commit()
    db.refresh(item)

    # Populate the per-user fields the way list_prompts does, so the response
    # reflects the caller's stored config.
    config = db.query(UserPromptConfig).filter(
        UserPromptConfig.template_id == item.id,
        UserPromptConfig.user_id == current_user.user_id,
    ).first()
    is_active = config.is_active if config is not None else None
    user_sort_order = config.user_sort_order if config is not None else None

    res = PromptTemplateOut.model_validate(item)
    res.is_active = is_active if is_active is not None else 1
    res.user_sort_order = (
        user_sort_order if user_sort_order is not None else item.sort_order
    )
    return res
@router.delete("/{prompt_id}")
def delete_prompt(
    prompt_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Delete a prompt template and return ``{"status": "ok"}``.

    System templates may only be deleted by callers passing ``is_admin``;
    non-system templates only by the user whose ``user_id`` matches.

    Raises:
        HTTPException: 404 when the template does not exist; 403 when the
            caller may not delete it.
    """
    template = (
        db.query(PromptTemplate)
        .filter(PromptTemplate.id == prompt_id)
        .first()
    )
    if not template:
        raise HTTPException(status_code=404, detail="模板不存在")

    caller_is_admin = is_admin(current_user)
    if template.is_system:
        # System template: admin rights required.
        if not caller_is_admin:
            raise HTTPException(status_code=403, detail="无权删除系统模板")
    elif template.user_id != current_user.user_id:
        # Personal template: only its owner may remove it.
        raise HTTPException(status_code=403, detail="无权删除他人模板")

    db.delete(template)
    db.commit()
    return {"status": "ok"}
@router.patch("/{prompt_id}/config")
def update_user_config(
    prompt_id: int,
    payload: UserPromptConfigUpdate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Update the caller's config for a template; return ``{"status": "ok"}``.

    If the caller has no ``UserPromptConfig`` row for the template, one is
    created with ``is_active=1`` and the template's ``sort_order`` before the
    payload is applied. Only fields explicitly set in the payload are written,
    and boolean values are converted to ``int`` first.

    Raises:
        HTTPException: 404 when no config row exists and the template does
            not exist either.
    """
    owner_id = current_user.user_id
    user_config = (
        db.query(UserPromptConfig)
        .filter(
            UserPromptConfig.template_id == prompt_id,
            UserPromptConfig.user_id == owner_id,
        )
        .first()
    )

    if not user_config:
        template = (
            db.query(PromptTemplate)
            .filter(PromptTemplate.id == prompt_id)
            .first()
        )
        if not template:
            raise HTTPException(status_code=404, detail="模板不存在")
        user_config = UserPromptConfig(
            user_id=owner_id,
            template_id=prompt_id,
            is_active=1,  # active by default, stored as an int
            user_sort_order=template.sort_order,
        )
        db.add(user_config)

    changes = payload.model_dump(exclude_unset=True)
    for field, value in changes.items():
        # Booleans are written as ints (the original comment says the target
        # columns are SmallInteger).
        setattr(user_config, field, int(value) if isinstance(value, bool) else value)

    db.commit()
    return {"status": "ok"}