210 lines
8.5 KiB
Python
210 lines
8.5 KiB
Python
import logging
from typing import List, Optional

from fastapi import APIRouter, Depends
from pydantic import BaseModel

from app.core.auth import get_current_user
from app.core.database import get_db_connection
from app.core.response import create_api_response
|
||
|
||
# Router that collects the prompt-management endpoints declared below
# (presumably included by the application under some API prefix — confirm).
router = APIRouter()
|
||
|
||
# Pydantic Models
class PromptIn(BaseModel):
    """Request body accepted by the create (POST) and update (PUT) prompt endpoints."""

    # Display name. The create/update handlers map a DB "Duplicate entry" error
    # to "name already exists", so presumably a unique key covers it — confirm scope.
    name: str
    task_type: str  # 'MEETING_TASK' or 'KNOWLEDGE_TASK' (not validated by this model)
    # Prompt text itself.
    content: str
    # When True, the handlers clear is_default on the other prompts of the same task_type.
    is_default: bool = False
    # Only active prompts are listed by GET /prompts/active/{task_type}.
    is_active: bool = True
|
||
|
||
class PromptOut(PromptIn):
    """A stored prompt: the PromptIn fields plus values filled in by the server/DB."""

    # Primary key (the create handler reports it from cursor.lastrowid).
    id: int
    # user_id of the creating user (set from current_user on insert).
    creator_id: int
    # NOTE(review): created_at is not set by the INSERT, so it comes from the DB,
    # presumably as a datetime. Declaring it as str may fail validation under
    # Pydantic v2. The handlers visible here return raw rows rather than this
    # model — confirm before relying on it.
    created_at: str
|
||
|
||
class PromptListResponse(BaseModel):
    """One page of prompts plus the total count of matching prompts.

    Mirrors the ``data`` payload shape built by the paginated list endpoint.
    """

    # Rows for the requested page only.
    prompts: List[PromptOut]
    # Number of matching prompts across all pages (from COUNT(*)).
    total: int
|
||
|
||
@router.post("/prompts")
def create_prompt(prompt: PromptIn, current_user: dict = Depends(get_current_user)):
    """Create a new prompt owned by the current user.

    If ``prompt.is_default`` is true, ``is_default`` is first cleared on every
    existing prompt with the same ``task_type``. The new prompt then becomes the
    only default for that type. Both statements are committed together. On any
    failure the transaction is rolled back explicitly, so a half-applied default
    reset cannot be committed later on the same connection.

    Args:
        prompt: Validated request body.
        current_user: Authenticated user; its ``user_id`` becomes ``creator_id``.

    Returns:
        An API response with one of these codes:
        - "200": carries the new row id plus the submitted fields.
        - "400": the database reported a duplicate entry (surfaced as
          "name already exists").
        - "500": any other error.
    """
    logger = logging.getLogger(__name__)
    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)
        try:
            # Keep a single default per task type: clear the flag on existing prompts first.
            if prompt.is_default:
                cursor.execute(
                    "UPDATE prompts SET is_default = FALSE WHERE task_type = %s",
                    (prompt.task_type,),
                )

            cursor.execute(
                """INSERT INTO prompts (name, task_type, content, is_default, is_active, creator_id)
                   VALUES (%s, %s, %s, %s, %s, %s)""",
                (prompt.name, prompt.task_type, prompt.content, prompt.is_default,
                 prompt.is_active, current_user["user_id"]),
            )
            connection.commit()
            new_id = cursor.lastrowid
            return create_api_response(
                code="200",
                message="提示词创建成功",
                data={"id": new_id, **prompt.dict()},
            )
        except Exception as e:
            is_duplicate = "Duplicate entry" in str(e)
            if not is_duplicate:
                logger.exception("Failed to create prompt %r", prompt.name)
            # Discard any partially applied work (e.g. the cleared default flags).
            try:
                connection.rollback()
            except Exception:
                logger.exception("Rollback failed after prompt creation error")
            if is_duplicate:
                return create_api_response(code="400", message="提示词名称已存在")
            return create_api_response(code="500", message=f"创建提示词失败: {e}")
|
||
|
||
@router.get("/prompts/active/{task_type}")
def get_active_prompts(task_type: str, current_user: dict = Depends(get_current_user)):
    """List the enabled prompts of one task type, for selection menus.

    Each row carries only ``id``, ``name`` and ``is_default``. Default prompts
    sort first, then newer prompts before older ones. No creator filter is
    applied, so prompts from every user are included. ``current_user`` is
    accepted only so the auth dependency runs.
    """
    active_sql = """SELECT id, name, is_default
                    FROM prompts
                    WHERE task_type = %s AND is_active = TRUE
                    ORDER BY is_default DESC, created_at DESC"""
    with get_db_connection() as conn:
        cur = conn.cursor(dictionary=True)
        cur.execute(active_sql, (task_type,))
        active_rows = cur.fetchall()
        payload = {"prompts": active_rows}
        return create_api_response(
            code="200",
            message="获取启用模版列表成功",
            data=payload,
        )
|
||
|
||
@router.get("/prompts")
def get_prompts(
    task_type: Optional[str] = None,
    page: int = 1,
    size: int = 50,
    current_user: dict = Depends(get_current_user),
):
    """Return one page of the current user's prompts, newest first.

    Args:
        task_type: If given (and non-empty), only prompts of this task type are listed.
        page: 1-based page number. Values below 1 are treated as 1.
        size: Page size. Negative values are treated as 0, which yields an empty page.
        current_user: Authenticated user. Only prompts they created are listed.

    Returns:
        An API response whose ``data`` holds ``prompts`` and ``total``.
        ``prompts`` contains the rows of the requested page. ``total`` is the
        number of matching prompts across all pages.
    """
    # MySQL rejects negative LIMIT/OFFSET values. Without this clamp, page < 1
    # or size < 0 surfaced as an unhandled server error instead of a result.
    page = max(page, 1)
    size = max(size, 0)

    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)

        # The WHERE clause is assembled only from fixed fragments; every
        # caller-supplied value is passed as a bound parameter.
        where_conditions = ["creator_id = %s"]
        params = [current_user["user_id"]]
        if task_type:
            where_conditions.append("task_type = %s")
            params.append(task_type)
        where_clause = " AND ".join(where_conditions)

        # Total number of matching rows, for the client's pagination.
        cursor.execute(
            f"SELECT COUNT(*) as total FROM prompts WHERE {where_clause}",
            tuple(params),
        )
        total = cursor.fetchone()["total"]

        # Rows of the requested page.
        offset = (page - 1) * size
        cursor.execute(
            f"""SELECT id, name, task_type, content, is_default, is_active, creator_id, created_at
                FROM prompts
                WHERE {where_clause}
                ORDER BY created_at DESC
                LIMIT %s OFFSET %s""",
            tuple(params + [size, offset]),
        )
        prompts = cursor.fetchall()
        return create_api_response(
            code="200",
            message="获取提示词列表成功",
            data={"prompts": prompts, "total": total},
        )
|
||
|
||
@router.get("/prompts/{prompt_id}")
def get_prompt(prompt_id: int, current_user: dict = Depends(get_current_user)):
    """Fetch a single prompt by primary key.

    Responds with code "200" and the full row as ``data`` when the prompt
    exists, otherwise with "404".
    """
    # NOTE(review): unlike delete_prompt there is no creator check here, so any
    # authenticated user can read any prompt by id — confirm this is intended.
    lookup_sql = """SELECT id, name, task_type, content, is_default, is_active, creator_id, created_at
                    FROM prompts WHERE id = %s"""
    with get_db_connection() as conn:
        cur = conn.cursor(dictionary=True)
        cur.execute(lookup_sql, (prompt_id,))
        row = cur.fetchone()
        if row:
            return create_api_response(code="200", message="获取提示词成功", data=row)
        return create_api_response(code="404", message="提示词不存在")
|
||
|
||
@router.put("/prompts/{prompt_id}")
def update_prompt(prompt_id: int, prompt: PromptIn, current_user: dict = Depends(get_current_user)):
    """Replace every editable field of an existing prompt.

    Only the prompt's creator may update it, which matches the rule that
    ``delete_prompt`` enforces. If ``prompt.is_default`` is true, ``is_default``
    is first cleared on all other prompts of the same ``task_type``. The
    statements are committed together and rolled back explicitly on any failure.

    Args:
        prompt_id: Primary key of the prompt to update.
        prompt: New values for all editable fields.
        current_user: Authenticated user; must be the prompt's creator.

    Returns:
        An API response with one of these codes:
        - "200": the prompt was updated.
        - "404": the prompt does not exist.
        - "403": the prompt belongs to another user.
        - "400": the database reported a duplicate entry (surfaced as
          "name already exists").
        - "500": any other error.
    """
    logger = logging.getLogger(__name__)
    user_id = current_user["user_id"]
    logger.debug(
        "Updating prompt %s for user %s: name=%r task_type=%r content_len=%d "
        "is_default=%s is_active=%s",
        prompt_id, user_id, prompt.name, prompt.task_type, len(prompt.content),
        prompt.is_default, prompt.is_active,
    )

    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)
        try:
            # Confirm the prompt exists and belongs to the caller before any write.
            cursor.execute("SELECT id, creator_id FROM prompts WHERE id = %s", (prompt_id,))
            existing = cursor.fetchone()
            if not existing:
                logger.debug("Prompt %s not found", prompt_id)
                return create_api_response(code="404", message="提示词不存在")
            if existing["creator_id"] != user_id:
                logger.debug("User %s may not update prompt %s", user_id, prompt_id)
                return create_api_response(code="403", message="无权修改其他用户的提示词")

            # Keep a single default per task type: clear the flag on the others first.
            if prompt.is_default:
                cursor.execute(
                    "UPDATE prompts SET is_default = FALSE WHERE task_type = %s AND id != %s",
                    (prompt.task_type, prompt_id),
                )
                logger.debug(
                    "Cleared %d other default prompts for task_type=%r",
                    cursor.rowcount, prompt.task_type,
                )

            cursor.execute(
                """UPDATE prompts
                   SET name = %s, task_type = %s, content = %s, is_default = %s, is_active = %s
                   WHERE id = %s""",
                (prompt.name, prompt.task_type, prompt.content, prompt.is_default,
                 prompt.is_active, prompt_id),
            )
            # rowcount can be 0 when every column already held the new value.
            # The row was confirmed to exist above, so that is not an error.
            connection.commit()
            return create_api_response(code="200", message="提示词更新成功")
        except Exception as e:
            is_duplicate = "Duplicate entry" in str(e)
            if not is_duplicate:
                logger.exception("Failed to update prompt %s", prompt_id)
            # Discard any partially applied work (e.g. the cleared default flags).
            try:
                connection.rollback()
            except Exception:
                logger.exception("Rollback failed after prompt update error")
            if is_duplicate:
                return create_api_response(code="400", message="提示词名称已存在")
            return create_api_response(code="500", message=f"更新提示词失败: {e}")
|
||
|
||
@router.delete("/prompts/{prompt_id}")
def delete_prompt(prompt_id: int, current_user: dict = Depends(get_current_user)):
    """Delete a prompt; only the user who created it is allowed to.

    Responds with one of these codes:
    - "404": no prompt has ``prompt_id``.
    - "403": the prompt belongs to someone else.
    - "200": the row has been deleted and the change committed.
    """
    with get_db_connection() as conn:
        cur = conn.cursor(dictionary=True)

        # Fetch the owner first so "missing" and "forbidden" can be told apart.
        cur.execute("SELECT creator_id FROM prompts WHERE id = %s", (prompt_id,))
        owner_row = cur.fetchone()
        if owner_row is None:
            return create_api_response(code="404", message="提示词不存在")

        owner_id = owner_row['creator_id']
        if owner_id != current_user["user_id"]:
            return create_api_response(code="403", message="无权删除其他用户的提示词")

        cur.execute("DELETE FROM prompts WHERE id = %s", (prompt_id,))
        conn.commit()
        return create_api_response(code="200", message="提示词删除成功")
|