345 lines
13 KiB
Python
345 lines
13 KiB
Python
"""
|
|
Backend-integrated MCP Streamable HTTP server.
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
from datetime import datetime
|
|
import hmac
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List
|
|
|
|
from fastapi import HTTPException, Response
|
|
from sqlalchemy import select
|
|
|
|
try:
|
|
from mcp.server.fastmcp import FastMCP
|
|
except ImportError: # pragma: no cover - runtime dependency
|
|
FastMCP = None
|
|
|
|
from app.core.database import AsyncSessionLocal
|
|
from app.models.mcp_bot import MCPBot
|
|
from app.models.project import Project, ProjectMember
|
|
from app.models.user import User
|
|
from app.schemas.project import ProjectResponse
|
|
from app.services.notification_service import notification_service
|
|
from app.services.search_service import search_service
|
|
from app.services.storage import storage_service
|
|
from app.services.log_service import log_service
|
|
from app.api.v1.projects import get_document_count
|
|
from app.api.v1.files import check_project_access
|
|
from app.core.config import settings
|
|
from app.core.enums import OperationType
|
|
from app.mcp.context import MCPRequestContext, current_mcp_request
|
|
|
|
|
|
# Build the FastMCP server only when the optional ``mcp`` package imported
# successfully (see the guarded import above); otherwise ``mcp`` stays None and
# no tools are registered.
if FastMCP:
    mcp = FastMCP(
        "NexDocs MCP",
        host=settings.HOST,
        port=settings.PORT,
        stateless_http=True,
        json_response=True,
        # Serve at the root of whatever path the returned ASGI app is mounted
        # on (presumably the caller mounts it under a prefix — confirm).
        streamable_http_path="/",
    )
else:
    mcp = None
|
|
|
|
|
|
async def _get_current_user(db) -> User:
    """Load the user bound to the current MCP request.

    Reads the context published by the header-auth middleware and re-queries
    the ``User`` row, requiring ``status == 1``.

    Args:
        db: Async SQLAlchemy session used for the lookup.

    Returns:
        The matching ``User``.

    Raises:
        RuntimeError: if no MCP request context is set, or no user with
            that id and ``status == 1`` exists.
    """
    request_ctx = current_mcp_request.get()
    if request_ctx is None:
        raise RuntimeError("MCP request context is missing.")

    query = select(User).where(User.id == request_ctx.user_id, User.status == 1)
    found = (await db.execute(query)).scalar_one_or_none()
    if not found:
        raise RuntimeError("Authenticated MCP user does not exist or is disabled.")
    return found
|
|
|
|
|
|
async def _get_project_with_write_access(project_id: int, current_user: User, db):
    """Resolve *project_id* through ``check_project_access`` demanding write rights.

    Thin wrapper so the mutating MCP tools share a single access policy; the
    result (or exception) of ``check_project_access`` is passed straight through.
    """
    project = await check_project_access(project_id, current_user, db, require_write=True)
    return project
|
|
|
|
|
|
def _ensure_file_exists(file_path: Path, path: str) -> None:
    """Require *file_path* to be an existing regular file.

    Args:
        file_path: Resolved filesystem path to check.
        path: Project-relative path, echoed back in error messages.

    Raises:
        HTTPException: 404 when nothing exists at the path, 400 when the path
            exists but is not a regular file (e.g. a directory).
    """
    if file_path.exists():
        if file_path.is_file():
            return
        raise HTTPException(status_code=400, detail=f"目标不是文件: {path}")
    raise HTTPException(status_code=404, detail=f"文件不存在: {path}")
|
|
|
|
|
|
def _ensure_file_not_exists(file_path: Path, path: str) -> None:
    """Refuse to proceed when anything already occupies *file_path*.

    Args:
        file_path: Resolved filesystem path that is about to be created.
        path: Project-relative path, echoed back in the error message.

    Raises:
        HTTPException: 400 if the path already exists (file or directory).
    """
    already_present = file_path.exists()
    if already_present:
        raise HTTPException(status_code=400, detail=f"文件已存在: {path}")
|
|
|
|
|
|
async def _update_markdown_index(project_id: int, path: str, content: str) -> None:
    """Push a Markdown document's latest content into the search index.

    Paths not ending in ``.md`` (case-sensitive) are ignored. The title handed
    to the index is the file name without its extension.
    """
    if not path.endswith(".md"):
        return
    title = Path(path).stem
    await search_service.update_doc(project_id, path, title, content)
|
|
|
|
|
|
async def _remove_markdown_index(project_id: int, path: str) -> None:
    """Drop a Markdown document from the search index.

    Paths not ending in ``.md`` (case-sensitive) are ignored.
    """
    if not path.endswith(".md"):
        return
    await search_service.remove_doc(project_id, path)
|
|
|
|
|
|
if mcp is not None:
    # Only the tools below need URL quoting, so the import lives with them.
    from urllib.parse import quote

    def _doc_link(project_id: int, path: str) -> str:
        """Return the in-app link to document *path* inside *project_id*.

        The path is percent-encoded (``/`` kept readable) so characters such as
        ``&``, ``#``, ``?``, ``+`` or spaces cannot corrupt the ``file`` query
        parameter.
        """
        return f"/projects/{project_id}/docs?file={quote(path, safe='/')}"

    async def _record_file_change(
        db,
        *,
        project_id: int,
        project,
        current_user: User,
        path: str,
        operation_type,
        detail: Dict[str, Any],
        title: str,
        verb: str,
        link: str | None = None,
    ) -> None:
        """Audit-log a file mutation and notify the other project members.

        Shared tail of the create/update/delete tools.

        Args:
            db: Session used for both the log entry and the notifications.
            project_id: Project the file belongs to.
            project: Loaded project row (its ``name`` appears in the message).
            current_user: Acting user; excluded from the notification fan-out.
            path: Project-relative file path.
            operation_type: ``OperationType`` member recorded in the log.
            detail: Extra detail stored with the log entry.
            title: Notification title.
            verb: Action word spliced into the notification text.
            link: Optional notification link; only forwarded when given, so
                ``notify_project_members`` keeps its own default otherwise.
        """
        await log_service.log_file_operation(
            db=db,
            operation_type=operation_type,
            project_id=project_id,
            file_path=path,
            user=current_user,
            detail=detail,
            request=None,  # MCP tools have no HTTP Request object to pass on
        )

        actor = current_user.nickname or current_user.username
        notify_kwargs: Dict[str, Any] = {
            "db": db,
            "project_id": project_id,
            "exclude_user_id": current_user.id,
            "title": title,
            "content": (
                f"项目 [{project.name}] 中的文档 [{path}] "
                f"已被 {actor} 通过 MCP {verb}。"
            ),
            "category": "project",
        }
        if link is not None:
            notify_kwargs["link"] = link
        await notification_service.notify_project_members(**notify_kwargs)

    @mcp.tool(name="list_created_projects", description="Get projects created by the authenticated user.")
    async def list_created_projects(keyword: str = "", limit: int = 100) -> List[Dict[str, Any]]:
        """List projects owned by the authenticated user that have ``status == 1``.

        Args:
            keyword: Optional case-insensitive substring matched against the
                project name and description; blank means no filtering.
            limit: Maximum number of projects returned; values <= 0 yield [].

        Returns:
            ``ProjectResponse`` dicts, each extended with a ``doc_count`` key.
        """
        needle = keyword.strip().lower()
        max_items = max(limit, 0)

        async with AsyncSessionLocal() as db:
            current_user = await _get_current_user(db)
            result = await db.execute(
                select(Project).where(Project.owner_id == current_user.id, Project.status == 1)
            )
            projects = result.scalars().all()

            items: List[Dict[str, Any]] = []
            for project in projects:
                if len(items) >= max_items:
                    break
                if needle:
                    haystack = f"{project.name} {project.description or ''}".lower()
                    if needle not in haystack:
                        continue
                # Serialize and count documents on disk only for projects that
                # pass the filter and fit within the limit.
                project_dict = ProjectResponse.model_validate(project, from_attributes=True).model_dump()
                project_dict["doc_count"] = get_document_count(project.storage_key)
                items.append(project_dict)

        return items

    @mcp.tool(name="get_project_tree", description="Get the directory tree of a specific project.")
    async def get_project_tree(project_id: int) -> Dict[str, Any]:
        """Return a project's directory tree, the caller's role and project metadata.

        Access is checked with ``check_project_access`` (read access).
        """
        async with AsyncSessionLocal() as db:
            current_user = await _get_current_user(db)
            project = await check_project_access(project_id, current_user, db)
            project_root = storage_service.get_secure_path(project.storage_key)
            tree = storage_service.generate_tree(project_root)

            user_role = "owner"
            if project.owner_id != current_user.id:
                member_result = await db.execute(
                    select(ProjectMember).where(
                        ProjectMember.project_id == project_id,
                        ProjectMember.user_id == current_user.id,
                    )
                )
                member = member_result.scalar_one_or_none()
                # NOTE(review): a non-owner who passed check_project_access but
                # has no ProjectMember row keeps the "owner" label — confirm
                # whether that case can occur and what role it should report.
                if member:
                    user_role = member.role

            return {
                "tree": [item.model_dump() for item in tree],
                "user_role": user_role,
                "project_name": project.name,
                "project_description": project.description,
            }

    @mcp.tool(name="get_file", description="Read a file from a specific project.")
    async def get_file(project_id: int, path: str) -> Dict[str, Any]:
        """Read *path* from the project (read access required) and return its content."""
        async with AsyncSessionLocal() as db:
            current_user = await _get_current_user(db)
            project = await check_project_access(project_id, current_user, db)
            file_path = storage_service.get_secure_path(project.storage_key, path)
            _ensure_file_exists(file_path, path)
            content = await storage_service.read_file(file_path)
            return {"path": path, "content": content}

    @mcp.tool(name="create_file", description="Create a new file in a specific project path.")
    async def create_file(project_id: int, path: str, content: str = "") -> Dict[str, Any]:
        """Create *path* with *content*; fails if anything already exists there.

        Requires write access. ``.md`` files are indexed for search, the
        operation is audit-logged and other project members are notified.
        """
        async with AsyncSessionLocal() as db:
            current_user = await _get_current_user(db)
            project = await _get_project_with_write_access(project_id, current_user, db)
            file_path = storage_service.get_secure_path(project.storage_key, path)
            _ensure_file_not_exists(file_path, path)
            await storage_service.write_file(file_path, content)
            await _update_markdown_index(project_id, path, content)

            await _record_file_change(
                db,
                project_id=project_id,
                project=project,
                current_user=current_user,
                path=path,
                operation_type=OperationType.CREATE_FILE,
                detail={"content_length": len(content), "source": "mcp"},
                title="项目文档创建",
                verb="创建",
                link=_doc_link(project_id, path),
            )

            return {
                "message": "文件创建成功",
                "project_id": project_id,
                "path": path,
            }

    @mcp.tool(name="update_file", description="Update an existing file in a specific project path.")
    async def update_file(project_id: int, path: str, content: str) -> Dict[str, Any]:
        """Overwrite the existing file *path* with *content*.

        Requires write access. ``.md`` files are re-indexed for search, the
        operation is audit-logged and other project members are notified.
        """
        async with AsyncSessionLocal() as db:
            current_user = await _get_current_user(db)
            project = await _get_project_with_write_access(project_id, current_user, db)
            file_path = storage_service.get_secure_path(project.storage_key, path)
            _ensure_file_exists(file_path, path)
            await storage_service.write_file(file_path, content)
            await _update_markdown_index(project_id, path, content)

            await _record_file_change(
                db,
                project_id=project_id,
                project=project,
                current_user=current_user,
                path=path,
                operation_type=OperationType.SAVE_FILE,
                detail={"content_length": len(content), "source": "mcp"},
                title="项目文档更新",
                verb="更新",
                link=_doc_link(project_id, path),
            )

            return {
                "message": "文件更新成功",
                "project_id": project_id,
                "path": path,
            }

    @mcp.tool(name="delete_file", description="Delete an existing file in a specific project path.")
    async def delete_file(project_id: int, path: str) -> Dict[str, Any]:
        """Delete the existing file *path*.

        Requires write access. ``.md`` files are removed from the search index,
        the operation is audit-logged and other project members are notified
        (without a link, since the document no longer exists).
        """
        async with AsyncSessionLocal() as db:
            current_user = await _get_current_user(db)
            project = await _get_project_with_write_access(project_id, current_user, db)
            file_path = storage_service.get_secure_path(project.storage_key, path)
            _ensure_file_exists(file_path, path)
            await storage_service.delete_file(file_path)
            await _remove_markdown_index(project_id, path)

            await _record_file_change(
                db,
                project_id=project_id,
                project=project,
                current_user=current_user,
                path=path,
                operation_type=OperationType.DELETE_FILE,
                detail={"source": "mcp"},
                title="项目文档删除",
                verb="删除",
            )

            return {
                "message": "文件删除成功",
                "project_id": project_id,
                "path": path,
            }
|
|
|
|
|
|
def create_mcp_http_app():
    """Build the ASGI application that serves MCP over streamable HTTP.

    Raises:
        RuntimeError: if the optional ``mcp`` package could not be imported.
    """
    server = mcp
    if server is not None:
        return server.streamable_http_app()
    raise RuntimeError("Package 'mcp' is required to run the MCP endpoint.")
|
|
|
|
|
|
def get_mcp_session_manager():
    """Expose the FastMCP streamable HTTP session manager.

    Raises:
        RuntimeError: if the optional ``mcp`` package could not be imported.
    """
    server = mcp
    if server is not None:
        return server.session_manager
    raise RuntimeError("Package 'mcp' is required to run the MCP endpoint.")
|
|
|
|
|
|
class MCPHeaderAuthApp:
    """ASGI wrapper that authenticates incoming MCP requests via bot headers.

    Non-HTTP scopes (e.g. ``lifespan``) are forwarded untouched. HTTP requests
    must send ``X-Bot-Id`` and ``X-Bot-Secret``; the bot and its owning user
    must both have ``status == 1`` and the secret must match. On success the
    bot id and user id are published through ``current_mcp_request`` while the
    wrapped app handles the request. Failures are answered directly with a JSON
    body: 401 for missing headers, 403 for an unknown bot or a wrong secret.
    """

    def __init__(self, app):
        # The wrapped downstream ASGI application.
        self.app = app

    @staticmethod
    def _secret_matches(expected, provided: str) -> bool:
        """Compare the stored and presented bot secrets in constant time.

        ``hmac.compare_digest`` raises ``TypeError`` for ``str`` arguments that
        contain non-ASCII characters. The presented secret comes from a
        client-controlled header decoded as latin-1, so both sides are compared
        as UTF-8 bytes, turning such input into a plain mismatch instead of a
        500. Assumes the stored secret is a ``str`` — confirm against MCPBot.
        """
        if not expected:
            return False
        return hmac.compare_digest(expected.encode("utf-8"), provided.encode("utf-8"))

    @staticmethod
    async def _send_error(scope, receive, send, status_code: int, content: str) -> None:
        """Answer the request directly with a JSON error body."""
        response = Response(
            content=content,
            status_code=status_code,
            media_type="application/json",
        )
        await response(scope, receive, send)

    async def _authenticate(self, bot_id: str, bot_secret: str):
        """Validate the bot credentials and stamp the bot's ``last_used_at``.

        Returns:
            ``(user_id, None)`` on success, or ``(None, (status_code, body))``
            describing the error response to send.
        """
        async with AsyncSessionLocal() as db:
            result = await db.execute(
                select(MCPBot, User)
                .join(User, User.id == MCPBot.user_id)
                .where(MCPBot.bot_id == bot_id, MCPBot.status == 1, User.status == 1)
            )
            row = result.first()
            if not row:
                return None, (403, '{"error":"Invalid MCP bot"}')

            mcp_bot, user = row
            if not self._secret_matches(mcp_bot.bot_secret, bot_secret):
                return None, (403, '{"error":"Invalid MCP secret"}')

            # Read the id before committing: commit may expire loaded
            # attributes, and lazily re-loading them is not possible on an
            # async session.
            user_id = user.id
            # NOTE(review): naive UTC timestamp via utcnow() (deprecated since
            # Python 3.12); switching to an aware value needs the column type
            # confirmed first.
            mcp_bot.last_used_at = datetime.utcnow()
            await db.commit()
            return user_id, None

    async def __call__(self, scope, receive, send):
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        headers = {
            key.decode("latin-1").lower(): value.decode("latin-1")
            for key, value in scope.get("headers", [])
        }
        bot_id = headers.get("x-bot-id", "").strip()
        bot_secret = headers.get("x-bot-secret", "").strip()

        if not bot_id or not bot_secret:
            await self._send_error(
                scope, receive, send, 401, '{"error":"Missing X-Bot-Id or X-Bot-Secret"}'
            )
            return

        user_id, failure = await self._authenticate(bot_id, bot_secret)
        if failure is not None:
            status_code, body = failure
            await self._send_error(scope, receive, send, status_code, body)
            return

        # The auth session is already closed here, so no extra DB session stays
        # open while the MCP tools run with their own sessions.
        token = current_mcp_request.set(MCPRequestContext(bot_id=bot_id, user_id=user_id))
        try:
            await self.app(scope, receive, send)
        finally:
            current_mcp_request.reset(token)
|