99 lines
4.3 KiB
Python
99 lines
4.3 KiB
Python
from fastapi import APIRouter, Depends
|
|
from pydantic import BaseModel
|
|
from typing import List, Optional
|
|
|
|
from app.core.auth import get_current_admin_user
|
|
from app.core.database import get_db_connection
|
|
from app.core.response import create_api_response
|
|
|
|
# Router holding the prompt CRUD endpoints defined in this module; every endpoint
# depends on get_current_admin_user. Presumably included into the app elsewhere.
router = APIRouter()
|
|
|
|
# Pydantic Models
|
|
class PromptIn(BaseModel):
    # Request body accepted by POST /prompts (create) and PUT /prompts/{prompt_id} (update).

    # Prompt name; the handlers map a duplicate-name database error to response code "400".
    name: str
    # Tag string written as-is into prompts.tags; its format is not constrained here.
    # Optional[...] means an explicit null is also accepted and passed through to the DB.
    tags: Optional[str] = ""
    # Prompt body text, written as-is into prompts.content.
    content: str
|
|
|
|
class PromptOut(PromptIn):
    # A prompt as read back from the prompts table: the input fields plus the columns
    # the INSERT does not set (id, created_at).
    # NOTE(review): not referenced by the handlers in this file, which return raw cursor rows.
    id: int
    # NOTE(review): declared as str, but created_at is selected straight from the database
    # and may come back as a datetime object; confirm before validating rows with this model.
    created_at: str
|
|
|
|
class PromptListResponse(BaseModel):
    # One page of prompts plus the total number of prompt rows.
    # NOTE(review): mirrors the `data` payload that get_prompts builds, but is not
    # referenced by the handlers in this file.
    prompts: List[PromptOut]
    total: int
|
|
|
|
@router.post("/prompts")
def create_prompt(prompt: PromptIn, current_user: dict = Depends(get_current_admin_user)):
    """Insert a new prompt row and echo it back together with its generated id.

    ``current_user`` is resolved by ``get_current_admin_user`` (presumably rejecting
    non-admin callers) and is not otherwise used.

    Returns:
        A ``create_api_response`` envelope: code "200" with ``{"id": ..., **fields}``
        on success; code "400" when the database reports a duplicate name;
        code "500" carrying the error text for any other failure.
    """
    insert_sql = "INSERT INTO prompts (name, tags, content) VALUES (%s, %s, %s)"
    with get_db_connection() as conn:
        cur = conn.cursor(dictionary=True)
        try:
            cur.execute(insert_sql, (prompt.name, prompt.tags, prompt.content))
            conn.commit()
            payload = {"id": cur.lastrowid, **prompt.dict()}
            return create_api_response(code="200", message="提示词创建成功", data=payload)
        except Exception as exc:
            err_text = str(exc)
            # Match both SQLite ("UNIQUE constraint failed") and MySQL ("Duplicate entry")
            # wording for a unique-key violation on the prompt name.
            is_duplicate = any(
                marker in err_text
                for marker in ("UNIQUE constraint failed", "Duplicate entry")
            )
            if is_duplicate:
                return create_api_response(code="400", message="提示词名称已存在")
            return create_api_response(code="500", message=f"创建提示词失败: {exc}")
|
|
|
|
@router.get("/prompts")
def get_prompts(page: int = 1, size: int = 12, current_user: dict = Depends(get_current_admin_user)):
    """Return one page of prompts, newest first, plus the total number of prompts.

    Args:
        page: 1-based page number; values below 1 are treated as 1.
        size: Maximum rows per page; negative values are treated as 0
            (an empty page that still reports ``total``).
        current_user: Resolved by ``get_current_admin_user``; not otherwise used.

    Returns:
        A ``create_api_response`` envelope with code "200" and
        ``data={"prompts": [row, ...], "total": N}``.
    """
    # A page below 1 would produce a negative OFFSET, and a negative size a negative
    # LIMIT; MySQL, for one, rejects both as a syntax error, turning bad query
    # parameters into a server error.
    page = max(page, 1)
    size = max(size, 0)
    offset = (page - 1) * size

    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)
        cursor.execute("SELECT COUNT(*) as total FROM prompts")
        total = cursor.fetchone()['total']

        # `id DESC` breaks ties between rows with equal created_at, so the ordering is
        # total and consecutive pages can neither repeat nor skip a row.
        cursor.execute(
            "SELECT id, name, tags, content, created_at FROM prompts "
            "ORDER BY created_at DESC, id DESC LIMIT %s OFFSET %s",
            (size, offset),
        )
        prompts = cursor.fetchall()
        return create_api_response(code="200", message="获取提示词列表成功", data={"prompts": prompts, "total": total})
|
|
|
|
@router.get("/prompts/{prompt_id}")
def get_prompt(prompt_id: int, current_user: dict = Depends(get_current_admin_user)):
    """Fetch a single prompt by its id.

    Returns:
        A ``create_api_response`` envelope: code "200" with the row (as a dict)
        when found, otherwise code "404".
    """
    query = "SELECT id, name, tags, content, created_at FROM prompts WHERE id = %s"
    with get_db_connection() as conn:
        cur = conn.cursor(dictionary=True)
        cur.execute(query, (prompt_id,))
        row = cur.fetchone()
        if row:
            return create_api_response(code="200", message="获取提示词成功", data=row)
        return create_api_response(code="404", message="提示词不存在")
|
|
|
|
@router.put("/prompts/{prompt_id}")
def update_prompt(prompt_id: int, prompt: PromptIn, current_user: dict = Depends(get_current_admin_user)):
    """Overwrite the name, tags and content of an existing prompt.

    Args:
        prompt_id: id of the prompt row to update.
        prompt: New field values.
        current_user: Resolved by ``get_current_admin_user``; not otherwise used.

    Returns:
        A ``create_api_response`` envelope: code "200" after committing; code "404"
        when no prompt has ``prompt_id``; code "400" when the database reports a
        duplicate name; code "500" carrying the error text for any other failure.
    """
    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)
        try:
            cursor.execute(
                "UPDATE prompts SET name = %s, tags = %s, content = %s WHERE id = %s",
                (prompt.name, prompt.tags, prompt.content, prompt_id),
            )
            if cursor.rowcount == 0:
                # MySQL reports rows *changed* (not rows matched) unless the client
                # sets CLIENT_FOUND_ROWS, so re-saving a prompt with identical values
                # also yields rowcount 0. Confirm the row is really missing before
                # answering 404 instead of trusting rowcount alone.
                cursor.execute("SELECT id FROM prompts WHERE id = %s", (prompt_id,))
                if cursor.fetchone() is None:
                    return create_api_response(code="404", message="提示词不存在")
            connection.commit()
            return create_api_response(code="200", message="提示词更新成功")
        except Exception as e:
            # Unique-key violation on name: SQLite and MySQL error wording respectively.
            if "UNIQUE constraint failed" in str(e) or "Duplicate entry" in str(e):
                return create_api_response(code="400", message="提示词名称已存在")
            return create_api_response(code="500", message=f"更新提示词失败: {e}")
|
|
|
|
@router.delete("/prompts/{prompt_id}")
def delete_prompt(prompt_id: int, current_user: dict = Depends(get_current_admin_user)):
    """Delete the prompt identified by ``prompt_id``.

    Returns:
        A ``create_api_response`` envelope: code "200" after the DELETE is
        committed, or code "404" (nothing committed) when no row matched.
    """
    with get_db_connection() as conn:
        cur = conn.cursor()
        cur.execute("DELETE FROM prompts WHERE id = %s", (prompt_id,))
        if cur.rowcount != 0:
            conn.commit()
            return create_api_response(code="200", message="提示词删除成功")
        return create_api_response(code="404", message="提示词不存在")
|