imetting_backend/app/api/endpoints/prompts.py

115 lines
4.9 KiB
Python

from fastapi import APIRouter, Depends
from pydantic import BaseModel
from typing import List, Optional
from app.core.auth import get_current_user
from app.core.database import get_db_connection
from app.core.response import create_api_response
# Router that the prompt endpoints below register on via @router.post/get/put/delete.
router = APIRouter()
# Pydantic Models
# Request body accepted by the create (POST) and update (PUT) prompt endpoints.
# (Documented with comments rather than a docstring so the generated OpenAPI
# schema description is not altered.)
class PromptIn(BaseModel):
    # Prompt name; the create/update handlers map a duplicate-key DB error to "name already exists".
    name: str
    # Free-form tag text stored as-is; defaults to an empty string when omitted.
    tags: Optional[str] = ""
    # The prompt text itself.
    content: str
# Prompt representation including the DB-assigned identifier and creation time.
# NOTE(review): not referenced as a response_model by the endpoints visible in
# this module; created_at is typed str while the DB driver may return a
# datetime — confirm before wiring it up as a response_model.
class PromptOut(PromptIn):
    # Primary key of the prompts row.
    id: int
    # Creation timestamp as a string.
    created_at: str
# Paginated list payload: one page of prompts plus the overall count.
# NOTE(review): mirrors the {"prompts", "total"} data returned by get_prompts,
# but is not referenced by the visible endpoints.
class PromptListResponse(BaseModel):
    # The prompts on the requested page.
    prompts: List[PromptOut]
    # Total number of matching prompts, independent of pagination.
    total: int
@router.post("/prompts")
def create_prompt(prompt: PromptIn, current_user: dict = Depends(get_current_user)):
    """Insert a new prompt owned by the calling user.

    Args:
        prompt: Name, tags and content supplied in the request body.
        current_user: Authenticated user from ``get_current_user``; its
            ``"user_id"`` is stored as the prompt's ``creator_id``.

    Returns:
        An API response with code "200" and the new row id plus the submitted
        fields on success, "400" when the database reports a duplicate entry /
        unique-constraint violation, or "500" with the error text otherwise.
    """
    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)
        try:
            row_values = (prompt.name, prompt.tags, prompt.content, current_user["user_id"])
            cursor.execute(
                "INSERT INTO prompts (name, tags, content, creator_id) VALUES (%s, %s, %s, %s)",
                row_values,
            )
            connection.commit()
            created = {"id": cursor.lastrowid, **prompt.dict()}
            return create_api_response(code="200", message="提示词创建成功", data=created)
        except Exception as e:
            # Duplicate names are detected from the driver's error text
            # (SQLite and MySQL phrasings are both recognised).
            error_text = str(e)
            is_duplicate = ("UNIQUE constraint failed" in error_text
                            or "Duplicate entry" in error_text)
            if is_duplicate:
                return create_api_response(code="400", message="提示词名称已存在")
            return create_api_response(code="500", message=f"创建提示词失败: {e}")
@router.get("/prompts")
def get_prompts(page: int = 1, size: int = 12, current_user: dict = Depends(get_current_user)):
    """Return one page of the prompts created by the current user.

    Args:
        page: 1-based page number; values below 1 are treated as 1.
        size: Maximum number of prompts per page; negative values are
            treated as 0.
        current_user: Authenticated user from ``get_current_user``; only
            prompts whose ``creator_id`` equals its ``"user_id"`` are listed.

    Returns:
        An API response whose ``data`` holds ``prompts`` (newest first by
        ``created_at``) and ``total`` (the user's prompt count, ignoring
        pagination).
    """
    # Clamp pagination inputs: page <= 0 (with size > 0) or a negative size
    # would render a negative OFFSET/LIMIT, which MySQL rejects with a syntax
    # error that escapes this handler instead of producing an API response.
    page = max(page, 1)
    size = max(size, 0)
    user_id = current_user["user_id"]
    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)
        # Only prompts created by the current user are visible here.
        cursor.execute(
            "SELECT COUNT(*) as total FROM prompts WHERE creator_id = %s",
            (user_id,),
        )
        total = cursor.fetchone()['total']
        offset = (page - 1) * size
        cursor.execute(
            "SELECT id, name, tags, content, created_at FROM prompts WHERE creator_id = %s ORDER BY created_at DESC LIMIT %s OFFSET %s",
            (user_id, size, offset),
        )
        prompts = cursor.fetchall()
        return create_api_response(code="200", message="获取提示词列表成功", data={"prompts": prompts, "total": total})
@router.get("/prompts/{prompt_id}")
def get_prompt(prompt_id: int, current_user: dict = Depends(get_current_user)):
    """Return a single prompt by id, provided the current user created it.

    Args:
        prompt_id: Primary key of the prompt to fetch.
        current_user: Authenticated user from ``get_current_user``; its
            ``"user_id"`` must equal the prompt's ``creator_id``.

    Returns:
        An API response with code "404" if no prompt has ``prompt_id``,
        "403" if it belongs to another user, otherwise "200" with the
        prompt's id, name, tags, content and created_at.
    """
    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)
        cursor.execute(
            "SELECT id, name, tags, content, created_at, creator_id FROM prompts WHERE id = %s",
            (prompt_id,),
        )
        prompt = cursor.fetchone()
        if not prompt:
            return create_api_response(code="404", message="提示词不存在")
        # creator_id is read only for the ownership check; remove it so the
        # response payload keeps the same fields as before.
        owner_id = prompt.pop("creator_id")
        # Same creator-only rule as the list and delete endpoints; previously
        # any authenticated user could read any prompt by guessing its id.
        if owner_id != current_user["user_id"]:
            return create_api_response(code="403", message="无权查看其他用户的提示词")
        return create_api_response(code="200", message="获取提示词成功", data=prompt)
@router.put("/prompts/{prompt_id}")
def update_prompt(prompt_id: int, prompt: PromptIn, current_user: dict = Depends(get_current_user)):
    """Replace the name, tags and content of a prompt owned by the current user.

    Args:
        prompt_id: Primary key of the prompt to update.
        prompt: New name, tags and content.
        current_user: Authenticated user from ``get_current_user``; its
            ``"user_id"`` must equal the prompt's ``creator_id``.

    Returns:
        An API response with code "404" if no prompt has ``prompt_id``,
        "403" if it belongs to another user, "400" when the database reports
        a duplicate name, "500" with the error text for other failures, and
        "200" after a successful update.
    """
    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)
        try:
            # Check existence explicitly rather than via the UPDATE's rowcount:
            # MySQL reports *changed* rows unless CLIENT_FOUND_ROWS is set, so
            # re-saving identical values used to yield a spurious 404.
            cursor.execute("SELECT creator_id FROM prompts WHERE id = %s", (prompt_id,))
            existing = cursor.fetchone()
            if not existing:
                return create_api_response(code="404", message="提示词不存在")
            # Only the creator may modify a prompt (same rule as deletion).
            if existing["creator_id"] != current_user["user_id"]:
                return create_api_response(code="403", message="无权修改其他用户的提示词")
            cursor.execute(
                "UPDATE prompts SET name = %s, tags = %s, content = %s WHERE id = %s",
                (prompt.name, prompt.tags, prompt.content, prompt_id),
            )
            connection.commit()
            return create_api_response(code="200", message="提示词更新成功")
        except Exception as e:
            # Duplicate names are detected from the driver's error text
            # (SQLite and MySQL phrasings are both recognised).
            if "UNIQUE constraint failed" in str(e) or "Duplicate entry" in str(e):
                return create_api_response(code="400", message="提示词名称已存在")
            return create_api_response(code="500", message=f"更新提示词失败: {e}")
@router.delete("/prompts/{prompt_id}")
def delete_prompt(prompt_id: int, current_user: dict = Depends(get_current_user)):
    """Remove a prompt, provided the caller is the user who created it.

    Args:
        prompt_id: Primary key of the prompt to delete.
        current_user: Authenticated user from ``get_current_user``; its
            ``"user_id"`` must equal the prompt's ``creator_id``.

    Returns:
        An API response with code "404" if no prompt has ``prompt_id``,
        "403" if it belongs to someone else, and "200" once the row has been
        deleted and committed.
    """
    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)
        # Look up the owner first so nothing is deleted without permission.
        cursor.execute("SELECT creator_id FROM prompts WHERE id = %s", (prompt_id,))
        owner_row = cursor.fetchone()
        if not owner_row:
            return create_api_response(code="404", message="提示词不存在")
        owner_id = owner_row['creator_id']
        if owner_id != current_user["user_id"]:
            return create_api_response(code="403", message="无权删除其他用户的提示词")
        cursor.execute("DELETE FROM prompts WHERE id = %s", (prompt_id,))
        connection.commit()
        return create_api_response(code="200", message="提示词删除成功")