import logging
from typing import List, Optional

from fastapi import APIRouter, Depends
from pydantic import BaseModel

from app.core.auth import get_current_user
from app.core.database import get_db_connection
from app.core.response import create_api_response

logger = logging.getLogger(__name__)

router = APIRouter()


# Pydantic models
class PromptIn(BaseModel):
    """Request body for creating or updating a prompt."""

    name: str
    task_type: str  # 'MEETING_TASK' or 'KNOWLEDGE_TASK'
    content: str
    is_default: bool = False
    is_active: bool = True


class PromptOut(PromptIn):
    """A stored prompt, including its database-assigned fields."""

    id: int
    creator_id: int
    created_at: str


class PromptListResponse(BaseModel):
    """A page of prompts plus the total number of matching rows."""

    prompts: List[PromptOut]
    total: int


def _is_duplicate_entry(exc: Exception) -> bool:
    """Return True if *exc*'s message contains "Duplicate entry".

    That is the wording MySQL uses for a unique-key violation; here it is
    presumably the unique index on the prompt name (TODO confirm schema).
    """
    return "Duplicate entry" in str(exc)


def _rollback_quietly(connection) -> None:
    """Roll back the open transaction, logging instead of raising on failure.

    Used in error paths so a partially applied write (e.g. the default-clearing
    UPDATE) is not left pending on the connection, while still letting the
    caller return its own error response.
    """
    try:
        connection.rollback()
    except Exception:
        logger.exception("Rollback failed")


@router.post("/prompts")
def create_prompt(prompt: PromptIn, current_user: dict = Depends(get_current_user)):
    """Create a new prompt owned by the current user.

    If ``prompt.is_default`` is true, every existing prompt with the same
    ``task_type`` is first un-marked as default. The clearing UPDATE and the
    INSERT are committed together; on any error the transaction is rolled back.

    Returns a 200 response with the new id and submitted fields, 400 when the
    database reports a duplicate entry, or 500 for any other failure.
    """
    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)
        try:
            # When this prompt becomes the default, clear the default flag on
            # all other prompts of the same task type first.
            if prompt.is_default:
                cursor.execute(
                    "UPDATE prompts SET is_default = FALSE WHERE task_type = %s",
                    (prompt.task_type,),
                )

            cursor.execute(
                """INSERT INTO prompts (name, task_type, content, is_default, is_active, creator_id)
                   VALUES (%s, %s, %s, %s, %s, %s)""",
                (
                    prompt.name,
                    prompt.task_type,
                    prompt.content,
                    prompt.is_default,
                    prompt.is_active,
                    current_user["user_id"],
                ),
            )
            connection.commit()
            new_id = cursor.lastrowid
            return create_api_response(
                code="200",
                message="提示词创建成功",
                data={"id": new_id, **prompt.dict()},
            )
        except Exception as e:
            # Undo the default-clearing UPDATE if the INSERT did not go through.
            _rollback_quietly(connection)
            logger.exception("Failed to create prompt %r", prompt.name)
            if _is_duplicate_entry(e):
                return create_api_response(code="400", message="提示词名称已存在")
            return create_api_response(code="500", message=f"创建提示词失败: {e}")


@router.get("/prompts/active/{task_type}")
def get_active_prompts(task_type: str, current_user: dict = Depends(get_current_user)):
    """List all active prompts of *task_type* (across all creators).

    Only ``id``, ``name`` and ``is_default`` are returned. Default prompts come
    first, then the newest ones.
    """
    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)
        cursor.execute(
            """SELECT id, name, is_default
               FROM prompts
               WHERE task_type = %s AND is_active = TRUE
               ORDER BY is_default DESC, created_at DESC""",
            (task_type,),
        )
        prompts = cursor.fetchall()
        return create_api_response(
            code="200",
            message="获取启用模版列表成功",
            data={"prompts": prompts},
        )


@router.get("/prompts")
def get_prompts(
    task_type: Optional[str] = None,
    page: int = 1,
    size: int = 50,
    current_user: dict = Depends(get_current_user),
):
    """Return a page of the current user's prompts, newest first.

    Args:
        task_type: If given, only prompts of this task type are returned.
        page: 1-based page number; values below 1 are treated as 1.
        size: Page size; values below 1 are treated as 1.

    The response data contains ``prompts`` (the requested page) and ``total``
    (the number of rows matching the filter, ignoring paging).
    """
    # Non-positive values would produce a negative LIMIT/OFFSET, which the
    # database rejects as a syntax error.
    page = max(page, 1)
    size = max(size, 1)

    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)

        # Build the WHERE clause; only fixed column names are interpolated,
        # all values go through driver parameters.
        where_conditions = ["creator_id = %s"]
        params = [current_user["user_id"]]
        if task_type:
            where_conditions.append("task_type = %s")
            params.append(task_type)
        where_clause = " AND ".join(where_conditions)

        # Total number of matching rows (ignoring pagination).
        cursor.execute(
            f"SELECT COUNT(*) as total FROM prompts WHERE {where_clause}",
            tuple(params),
        )
        total = cursor.fetchone()["total"]

        # Requested page of rows.
        offset = (page - 1) * size
        cursor.execute(
            f"""SELECT id, name, task_type, content, is_default, is_active, creator_id, created_at
                FROM prompts
                WHERE {where_clause}
                ORDER BY created_at DESC
                LIMIT %s OFFSET %s""",
            tuple(params + [size, offset]),
        )
        prompts = cursor.fetchall()

        return create_api_response(
            code="200",
            message="获取提示词列表成功",
            data={"prompts": prompts, "total": total},
        )


@router.get("/prompts/{prompt_id}")
def get_prompt(prompt_id: int, current_user: dict = Depends(get_current_user)):
    """Return a single prompt by id, or a 404 response if it does not exist."""
    # NOTE(review): unlike get_prompts/update_prompt/delete_prompt this does
    # not restrict access to the creator; confirm prompts are meant to be
    # readable by every authenticated user.
    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)
        cursor.execute(
            """SELECT id, name, task_type, content, is_default, is_active, creator_id, created_at
               FROM prompts WHERE id = %s""",
            (prompt_id,),
        )
        prompt = cursor.fetchone()
        if not prompt:
            return create_api_response(code="404", message="提示词不存在")
        return create_api_response(code="200", message="获取提示词成功", data=prompt)


@router.put("/prompts/{prompt_id}")
def update_prompt(prompt_id: int, prompt: PromptIn, current_user: dict = Depends(get_current_user)):
    """Replace all editable fields of an existing prompt.

    Only the prompt's creator may update it (same rule as ``delete_prompt``).
    If ``prompt.is_default`` is true, every other prompt of the same task type
    is un-marked as default. The writes are committed together, and any error
    rolls them back.

    Returns 200 on success, 404 if the prompt does not exist, 403 if it belongs
    to another user, 400 when the database reports a duplicate entry, or 500
    for any other failure.
    """
    logger.debug(
        "[UPDATE PROMPT] prompt_id=%s user_id=%s name=%r task_type=%s "
        "content_len=%d is_default=%s is_active=%s",
        prompt_id,
        current_user["user_id"],
        prompt.name,
        prompt.task_type,
        len(prompt.content),
        prompt.is_default,
        prompt.is_active,
    )
    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)
        try:
            # Verify the record exists and is owned by the caller.
            cursor.execute("SELECT id, creator_id FROM prompts WHERE id = %s", (prompt_id,))
            existing = cursor.fetchone()
            if not existing:
                logger.debug("[UPDATE PROMPT] Prompt %s not found", prompt_id)
                return create_api_response(code="404", message="提示词不存在")
            if existing["creator_id"] != current_user["user_id"]:
                return create_api_response(code="403", message="无权修改其他用户的提示词")

            # When this prompt becomes the default, clear the flag on the
            # other prompts of the same task type.
            if prompt.is_default:
                cursor.execute(
                    "UPDATE prompts SET is_default = FALSE WHERE task_type = %s AND id != %s",
                    (prompt.task_type, prompt_id),
                )
                logger.debug("[UPDATE PROMPT] Cleared %s other default prompts", cursor.rowcount)

            cursor.execute(
                """UPDATE prompts
                   SET name = %s, task_type = %s, content = %s, is_default = %s, is_active = %s
                   WHERE id = %s""",
                (
                    prompt.name,
                    prompt.task_type,
                    prompt.content,
                    prompt.is_default,
                    prompt.is_active,
                    prompt_id,
                ),
            )
            # rowcount == 0 does not mean the row is missing (existence was
            # checked above); it can mean every field already had these values.
            logger.debug("[UPDATE PROMPT] UPDATE affected %s rows", cursor.rowcount)

            connection.commit()
            return create_api_response(code="200", message="提示词更新成功")
        except Exception as e:
            _rollback_quietly(connection)
            logger.exception("Failed to update prompt %s", prompt_id)
            if _is_duplicate_entry(e):
                return create_api_response(code="400", message="提示词名称已存在")
            return create_api_response(code="500", message=f"更新提示词失败: {e}")


@router.delete("/prompts/{prompt_id}")
def delete_prompt(prompt_id: int, current_user: dict = Depends(get_current_user)):
    """Delete a prompt. Only its creator may delete it.

    Returns 200 on success, 404 if the prompt does not exist, or 403 if it
    belongs to another user.
    """
    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)

        # Check that the prompt exists and belongs to the current user.
        cursor.execute("SELECT creator_id FROM prompts WHERE id = %s", (prompt_id,))
        prompt = cursor.fetchone()
        if not prompt:
            return create_api_response(code="404", message="提示词不存在")
        if prompt["creator_id"] != current_user["user_id"]:
            return create_api_response(code="403", message="无权删除其他用户的提示词")

        cursor.execute("DELETE FROM prompts WHERE id = %s", (prompt_id,))
        connection.commit()
        return create_api_response(code="200", message="提示词删除成功")