imetting_backend/app/api/endpoints/prompts.py

210 lines
8.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

from fastapi import APIRouter, Depends
from pydantic import BaseModel
from typing import List, Optional
from app.core.auth import get_current_user
from app.core.database import get_db_connection
from app.core.response import create_api_response
# Router that collects the prompt CRUD endpoints declared in this module via
# the @router.* decorators; presumably included by the app elsewhere — TODO
# confirm the mount prefix.
router = APIRouter()

# Pydantic Models
class PromptIn(BaseModel):
    """Request body for creating or updating a prompt.

    Fields:
        name: Display name. The endpoints map a database "Duplicate entry"
            error to "name already exists", which suggests a unique
            constraint on this column (TODO confirm against the schema).
        task_type: Category the prompt belongs to (free-form string here).
        content: The prompt text.
        is_default: When True, the create/update endpoints clear
            ``is_default`` on other prompts with the same ``task_type``.
        is_active: Only prompts with this set are returned by the
            active-prompts listing endpoint.
    """
    name: str
    task_type: str  # expected 'MEETING_TASK' or 'KNOWLEDGE_TASK'; not validated here
    content: str
    is_default: bool = False
    is_active: bool = True
class PromptOut(PromptIn):
    """A stored prompt: the ``PromptIn`` fields plus database-assigned columns.

    NOTE(review): ``created_at`` is declared as ``str``; if the DB driver
    returns a ``datetime`` for that column, validating rows with this model
    would fail without a conversion — confirm before using it as a
    ``response_model``. It is not referenced by the endpoints in this module.
    """
    id: int  # primary key of the prompts row
    creator_id: int  # user_id of the user who created the prompt
    created_at: str
class PromptListResponse(BaseModel):
    """Shape of a paginated prompt listing.

    Mirrors the ``data`` payload built by ``get_prompts``: one page of prompt
    rows plus the total number of matching prompts. Not referenced as a
    ``response_model`` by the endpoints in this module.
    """
    prompts: List[PromptOut]
    total: int
@router.post("/prompts")
def create_prompt(prompt: PromptIn, current_user: dict = Depends(get_current_user)):
    """Create a new prompt owned by the current user.

    If ``prompt.is_default`` is set, every other prompt with the same
    ``task_type`` has its default flag cleared, leaving the new prompt as the
    only default for that task type.

    Args:
        prompt: Validated request body.
        current_user: Authenticated user; ``current_user["user_id"]`` is stored
            as the prompt's ``creator_id``.

    Returns:
        A ``create_api_response`` payload: code "200" with the new id and the
        submitted fields, "400" when the name already exists, or "500" for any
        other database error.
    """
    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)
        try:
            # Insert first: if this fails (e.g. duplicate name), the default
            # flags of existing prompts are never touched, even when the
            # connection autocommits each statement.
            cursor.execute(
                """INSERT INTO prompts (name, task_type, content, is_default, is_active, creator_id)
                VALUES (%s, %s, %s, %s, %s, %s)""",
                (prompt.name, prompt.task_type, prompt.content, prompt.is_default,
                 prompt.is_active, current_user["user_id"])
            )
            new_id = cursor.lastrowid
            # A new default replaces any existing default of the same task type.
            if prompt.is_default:
                cursor.execute(
                    "UPDATE prompts SET is_default = FALSE WHERE task_type = %s AND id != %s",
                    (prompt.task_type, new_id)
                )
            connection.commit()
            return create_api_response(
                code="200",
                message="提示词创建成功",
                data={"id": new_id, **prompt.dict()}
            )
        except Exception as e:
            # Discard any partial writes from this request. This is best-effort:
            # if the connection itself is broken, rollback may also fail, and we
            # still want to report the original error.
            try:
                connection.rollback()
            except Exception:
                pass
            if "Duplicate entry" in str(e):
                return create_api_response(code="400", message="提示词名称已存在")
            return create_api_response(code="500", message=f"创建提示词失败: {e}")
@router.get("/prompts/active/{task_type}")
def get_active_prompts(task_type: str, current_user: dict = Depends(get_current_user)):
    """List enabled prompts of one task type, default prompt(s) first.

    Prompts from every creator are included. Each row carries only ``id``,
    ``name`` and ``is_default``. Rows with the same default flag are ordered
    newest first.
    """
    with get_db_connection() as conn:
        cur = conn.cursor(dictionary=True)
        query_args = (task_type,)
        cur.execute(
            """SELECT id, name, is_default
            FROM prompts
            WHERE task_type = %s AND is_active = TRUE
            ORDER BY is_default DESC, created_at DESC""",
            query_args
        )
        rows = cur.fetchall()
        payload = {"prompts": rows}
        return create_api_response(
            code="200",
            message="获取启用模版列表成功",
            data=payload
        )
@router.get("/prompts")
def get_prompts(
    task_type: Optional[str] = None,
    page: int = 1,
    size: int = 50,
    current_user: dict = Depends(get_current_user)
):
    """Return one page of the current user's prompts, newest first.

    Args:
        task_type: Optional filter. When non-empty, only prompts of this task
            type are counted and returned.
        page: 1-based page number. Values below 1 are treated as 1.
        size: Maximum rows per page. Negative values are treated as 0, which
            yields an empty page but still reports ``total``.
        current_user: Authenticated user. Only prompts whose ``creator_id``
            equals ``current_user["user_id"]`` are included.

    Returns:
        A ``create_api_response`` payload whose ``data`` holds ``prompts`` (the
        rows of the requested page) and ``total`` (number of matching prompts
        across all pages).
    """
    # A page below 1 or a negative size would produce a negative OFFSET/LIMIT,
    # which the database rejects. Clamp them so such requests get a valid
    # (possibly empty) page instead of an unhandled error.
    page = max(page, 1)
    size = max(size, 0)
    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)
        # The WHERE clause is assembled only from fixed fragments; every value
        # is bound through a placeholder, so the f-strings below carry no user input.
        where_conditions = ["creator_id = %s"]
        params = [current_user["user_id"]]
        if task_type:
            where_conditions.append("task_type = %s")
            params.append(task_type)
        where_clause = " AND ".join(where_conditions)
        # Total number of matching rows, independent of pagination.
        cursor.execute(
            f"SELECT COUNT(*) as total FROM prompts WHERE {where_clause}",
            tuple(params)
        )
        total = cursor.fetchone()['total']
        # Fetch the requested page.
        offset = (page - 1) * size
        cursor.execute(
            f"""SELECT id, name, task_type, content, is_default, is_active, creator_id, created_at
            FROM prompts
            WHERE {where_clause}
            ORDER BY created_at DESC
            LIMIT %s OFFSET %s""",
            tuple(params + [size, offset])
        )
        prompts = cursor.fetchall()
        return create_api_response(
            code="200",
            message="获取提示词列表成功",
            data={"prompts": prompts, "total": total}
        )
@router.get("/prompts/{prompt_id}")
def get_prompt(prompt_id: int, current_user: dict = Depends(get_current_user)):
    """Fetch a single prompt by its primary key.

    Returns a "200" payload containing the full row, or a "404" payload when
    no prompt has ``prompt_id``.

    NOTE(review): unlike ``delete_prompt``, this lookup is not restricted to
    the caller's own prompts. Confirm that reading other users' prompts is
    intended.
    """
    with get_db_connection() as conn:
        cur = conn.cursor(dictionary=True)
        cur.execute(
            """SELECT id, name, task_type, content, is_default, is_active, creator_id, created_at
            FROM prompts WHERE id = %s""",
            (prompt_id,)
        )
        row = cur.fetchone()
        if row:
            return create_api_response(code="200", message="获取提示词成功", data=row)
        return create_api_response(code="404", message="提示词不存在")
@router.put("/prompts/{prompt_id}")
def update_prompt(prompt_id: int, prompt: PromptIn, current_user: dict = Depends(get_current_user)):
    """Overwrite all editable fields of a prompt owned by the current user.

    If ``prompt.is_default`` is set, every other prompt with the same
    ``task_type`` has its default flag cleared.

    Args:
        prompt_id: Primary key of the prompt to update.
        prompt: New values; every ``PromptIn`` column is overwritten.
        current_user: Authenticated user; must be the prompt's creator, which is
            the same ownership rule ``delete_prompt`` enforces.

    Returns:
        A ``create_api_response`` payload: "200" on success, "404" if the prompt
        does not exist, "403" if it belongs to another user, "400" if the new
        name collides with an existing one, or "500" for other database errors.
    """
    import logging

    logger = logging.getLogger(__name__)
    user_id = current_user["user_id"]
    logger.debug(
        "[UPDATE PROMPT] prompt_id=%s user_id=%s name=%s task_type=%s "
        "content_len=%d is_default=%s is_active=%s",
        prompt_id, user_id, prompt.name, prompt.task_type, len(prompt.content),
        prompt.is_default, prompt.is_active,
    )
    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)
        try:
            # Confirm the prompt exists and find its owner before writing.
            cursor.execute("SELECT id, creator_id FROM prompts WHERE id = %s", (prompt_id,))
            existing = cursor.fetchone()
            if not existing:
                logger.debug("[UPDATE PROMPT] Prompt %s not found in database", prompt_id)
                return create_api_response(code="404", message="提示词不存在")
            # Previously creator_id was fetched but never checked, so any user
            # could overwrite anyone's prompt. Only the creator may modify it.
            if existing["creator_id"] != user_id:
                return create_api_response(code="403", message="无权修改其他用户的提示词")
            cursor.execute(
                """UPDATE prompts
                SET name = %s, task_type = %s, content = %s, is_default = %s, is_active = %s
                WHERE id = %s""",
                (prompt.name, prompt.task_type, prompt.content, prompt.is_default,
                 prompt.is_active, prompt_id)
            )
            # rowcount may be 0 when every new value equals the stored one. The
            # row's existence was already confirmed, so that is not an error.
            logger.debug("[UPDATE PROMPT] UPDATE affected %s rows", cursor.rowcount)
            # Clear other defaults only after the row update succeeded, so a
            # failed update (e.g. duplicate name) leaves other prompts untouched.
            if prompt.is_default:
                cursor.execute(
                    "UPDATE prompts SET is_default = FALSE WHERE task_type = %s AND id != %s",
                    (prompt.task_type, prompt_id)
                )
                logger.debug("[UPDATE PROMPT] Cleared %s other default prompts", cursor.rowcount)
            connection.commit()
            return create_api_response(code="200", message="提示词更新成功")
        except Exception as e:
            logger.warning("[UPDATE PROMPT] Exception: %s: %s", type(e).__name__, e)
            # Discard partial writes. This is best-effort so the original error
            # is still reported even if the connection is unusable.
            try:
                connection.rollback()
            except Exception:
                pass
            if "Duplicate entry" in str(e):
                return create_api_response(code="400", message="提示词名称已存在")
            return create_api_response(code="500", message=f"更新提示词失败: {e}")
@router.delete("/prompts/{prompt_id}")
def delete_prompt(prompt_id: int, current_user: dict = Depends(get_current_user)):
    """Delete a prompt on behalf of its creator.

    Responds "404" when the prompt does not exist and "403" when it belongs to
    a different user. Otherwise it deletes the row, commits, and responds "200".
    """
    with get_db_connection() as conn:
        cur = conn.cursor(dictionary=True)
        # Look up the owner first so "missing" can be told apart from "forbidden".
        cur.execute(
            "SELECT creator_id FROM prompts WHERE id = %s",
            (prompt_id,)
        )
        owner_row = cur.fetchone()
        if not owner_row:
            return create_api_response(code="404", message="提示词不存在")
        caller_id = current_user["user_id"]
        if owner_row['creator_id'] != caller_id:
            return create_api_response(code="403", message="无权删除其他用户的提示词")
        cur.execute("DELETE FROM prompts WHERE id = %s", (prompt_id,))
        conn.commit()
        return create_api_response(code="200", message="提示词删除成功")