115 lines
4.9 KiB
Python
115 lines
4.9 KiB
Python
from fastapi import APIRouter, Depends
|
|
from pydantic import BaseModel
|
|
from typing import List, Optional
|
|
|
|
from app.core.auth import get_current_user
|
|
from app.core.database import get_db_connection
|
|
from app.core.response import create_api_response
|
|
|
|
# Router on which the prompt endpoints in this module are registered.
router = APIRouter()
|
|
|
|
# Pydantic Models
|
|
class PromptIn(BaseModel):
    # Request-body schema used by the create and update prompt endpoints.
    name: str
    # Tags are carried as one optional string; "" is used when omitted.
    tags: Optional[str] = ""
    content: str
|
|
|
|
class PromptOut(PromptIn):
    # Extends PromptIn with the `id` and `created_at` fields.
    id: int
    created_at: str
|
|
|
|
class PromptListResponse(BaseModel):
    # A list of PromptOut items together with a total count.
    prompts: List[PromptOut]
    total: int
|
|
|
|
@router.post("/prompts")
def create_prompt(prompt: PromptIn, current_user: dict = Depends(get_current_user)):
    """Insert a new prompt owned by the authenticated user.

    Returns an API response with code "200" carrying the new row id merged
    with the submitted fields, code "400" when the database error text
    indicates a uniqueness violation, or code "500" for any other failure.
    """
    insert_sql = "INSERT INTO prompts (name, tags, content, creator_id) VALUES (%s, %s, %s, %s)"
    with get_db_connection() as conn:
        cur = conn.cursor(dictionary=True)
        try:
            params = (prompt.name, prompt.tags, prompt.content, current_user["user_id"])
            cur.execute(insert_sql, params)
            conn.commit()
            payload = {"id": cur.lastrowid, **prompt.dict()}
            return create_api_response(code="200", message="提示词创建成功", data=payload)
        except Exception as exc:
            # Uniqueness violations are recognised by the driver's error text.
            duplicate_markers = ("UNIQUE constraint failed", "Duplicate entry")
            if any(marker in str(exc) for marker in duplicate_markers):
                return create_api_response(code="400", message="提示词名称已存在")
            return create_api_response(code="500", message=f"创建提示词失败: {exc}")
|
|
|
|
@router.get("/prompts")
def get_prompts(page: int = 1, size: int = 12, current_user: dict = Depends(get_current_user)):
    """Return one page of the current user's prompts, newest first.

    Args:
        page: 1-based page number; values below 1 are treated as 1.
        size: Number of prompts per page; negative values are treated as 0.
        current_user: Authenticated user; its "user_id" selects the rows.

    Returns:
        API response whose data holds "prompts" (the rows of the requested
        page) and "total" (count of all prompts created by the user).
    """
    # A non-positive page or a negative size would produce a negative
    # OFFSET/LIMIT, which is either rejected by the database or silently
    # changes the result size, so clamp both before building the query.
    page = max(page, 1)
    size = max(size, 0)
    offset = (page - 1) * size

    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)

        # Only count prompts created by the current user.
        cursor.execute(
            "SELECT COUNT(*) as total FROM prompts WHERE creator_id = %s",
            (current_user["user_id"],)
        )
        total = cursor.fetchone()['total']

        # `id` breaks ties between rows with equal created_at values so the
        # ordering is total and rows do not repeat or vanish across pages.
        cursor.execute(
            "SELECT id, name, tags, content, created_at FROM prompts "
            "WHERE creator_id = %s ORDER BY created_at DESC, id DESC LIMIT %s OFFSET %s",
            (current_user["user_id"], size, offset)
        )
        prompts = cursor.fetchall()
        return create_api_response(code="200", message="获取提示词列表成功", data={"prompts": prompts, "total": total})
|
|
|
|
@router.get("/prompts/{prompt_id}")
def get_prompt(prompt_id: int, current_user: dict = Depends(get_current_user)):
    """Fetch one prompt by primary key.

    Responds with code "200" and the row (id, name, tags, content,
    created_at) when it exists, otherwise with code "404".
    """
    # NOTE(review): the lookup is not restricted to the caller's own prompts,
    # unlike the list and delete endpoints - confirm this is intended.
    select_sql = "SELECT id, name, tags, content, created_at FROM prompts WHERE id = %s"
    with get_db_connection() as conn:
        cur = conn.cursor(dictionary=True)
        cur.execute(select_sql, (prompt_id,))
        row = cur.fetchone()
        if row:
            return create_api_response(code="200", message="获取提示词成功", data=row)
        return create_api_response(code="404", message="提示词不存在")
|
|
|
|
@router.put("/prompts/{prompt_id}")
def update_prompt(prompt_id: int, prompt: PromptIn, current_user: dict = Depends(get_current_user)):
    """Update an existing prompt. Only the creator can update their own prompts.

    Args:
        prompt_id: Id of the prompt to update.
        prompt: New name, tags and content for the prompt.
        current_user: Authenticated user; its "user_id" must match the
            prompt's creator_id.

    Returns:
        API response with code "200" on success, "404" when no prompt has
        the given id, "403" when the prompt belongs to another user, "400"
        when the database error text indicates a uniqueness violation, and
        "500" for any other failure.
    """
    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)
        try:
            # Check existence and ownership up front, as the delete endpoint
            # does. The UPDATE's rowcount cannot be used for the 404 check:
            # MySQL by default reports rows actually changed, so saving
            # unchanged values would look like a missing prompt.
            cursor.execute(
                "SELECT creator_id FROM prompts WHERE id = %s",
                (prompt_id,)
            )
            existing = cursor.fetchone()

            if not existing:
                return create_api_response(code="404", message="提示词不存在")

            if existing['creator_id'] != current_user["user_id"]:
                return create_api_response(code="403", message="无权修改其他用户的提示词")

            # creator_id is repeated in the WHERE clause so the statement can
            # never touch a row owned by someone else.
            cursor.execute(
                "UPDATE prompts SET name = %s, tags = %s, content = %s WHERE id = %s AND creator_id = %s",
                (prompt.name, prompt.tags, prompt.content, prompt_id, current_user["user_id"])
            )
            connection.commit()
            return create_api_response(code="200", message="提示词更新成功")
        except Exception as e:
            # Uniqueness violations are recognised by the driver's error text.
            if "UNIQUE constraint failed" in str(e) or "Duplicate entry" in str(e):
                return create_api_response(code="400", message="提示词名称已存在")
            return create_api_response(code="500", message=f"更新提示词失败: {e}")
|
|
|
|
@router.delete("/prompts/{prompt_id}")
def delete_prompt(prompt_id: int, current_user: dict = Depends(get_current_user)):
    """Remove a prompt on behalf of its creator.

    Responds with code "404" when no prompt has the given id, "403" when the
    prompt's creator_id differs from the caller's user_id, and "200" once the
    row has been deleted and the change committed.
    """
    with get_db_connection() as conn:
        cur = conn.cursor(dictionary=True)

        # Look up the owner first so that a missing prompt and a prompt owned
        # by someone else can be reported differently.
        lookup_args = (prompt_id,)
        cur.execute("SELECT creator_id FROM prompts WHERE id = %s", lookup_args)
        row = cur.fetchone()
        if not row:
            return create_api_response(code="404", message="提示词不存在")

        owner_id = row['creator_id']
        if owner_id != current_user["user_id"]:
            return create_api_response(code="403", message="无权删除其他用户的提示词")

        # Ownership confirmed: delete the prompt.
        cur.execute("DELETE FROM prompts WHERE id = %s", lookup_args)
        conn.commit()
        return create_api_response(code="200", message="提示词删除成功")
|