271 lines
11 KiB
Python
271 lines
11 KiB
Python
import sys
|
|
import uuid
|
|
import os
|
|
import re
|
|
import tempfile
|
|
import asyncio
|
|
import uvicorn
|
|
import click
|
|
import zipfile
|
|
from pathlib import Path
|
|
import glob
|
|
from fastapi import Depends, FastAPI, HTTPException, UploadFile, File, Form, APIRouter
|
|
from fastapi.middleware.gzip import GZipMiddleware
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.responses import JSONResponse, FileResponse
|
|
from fastapi.staticfiles import StaticFiles
|
|
from starlette.background import BackgroundTask
|
|
from typing import List, Optional
|
|
from loguru import logger
|
|
from base64 import b64encode
|
|
|
|
# MinerU internal imports
|
|
from mineru.cli.common import aio_do_parse, read_fn, pdf_suffixes, image_suffixes, word_suffixes
|
|
from mineru.utils.cli_parser import arg_parse
|
|
from mineru.utils.guess_suffix_or_lang import guess_suffix_by_path
|
|
from mineru.version import __version__
|
|
|
|
# --- Logging configuration ---
# MINERU_LOG_LEVEL selects the stderr log level (case-insensitive, default "INFO").
log_level = os.environ.get("MINERU_LOG_LEVEL", "INFO").upper()

# Remove every previously registered handler (including loguru's default one),
# then log to stderr only, at the level chosen above.
logger.remove()
logger.add(sink=sys.stderr, level=log_level)

# --- Global state and helper classes ---
# Request-concurrency semaphore consulted by limit_concurrency(); None means
# "no limit". It is assigned when a positive concurrency limit is configured
# (see create_app).
_request_semaphore: Optional[asyncio.Semaphore] = None
|
|
|
|
|
|
class ProgressTracker:
    """Mutable holder for a numeric progress value and a status label."""

    def __init__(self):
        # Start at zero progress with the "initializing" status label.
        self.progress, self.status = 0, "初始化"

    def update(self, progress: float, status: str):
        """Overwrite both the progress value and the status label."""
        self.progress, self.status = progress, status

    def get_progress(self):
        """Return a new dict snapshot with keys "progress" and "status"."""
        return dict(progress=self.progress, status=self.status)
|
|
|
|
|
|
async def limit_concurrency():
    """FastAPI dependency that caps the number of in-flight requests.

    With no semaphore configured (module-level ``_request_semaphore`` is None) the
    request proceeds unconditionally. Otherwise the request is rejected with HTTP 503
    when every slot is taken; if a slot is free it is held for the lifetime of the
    request and released when the dependency is finalized.

    Raises:
        HTTPException: 503 when the server is already at its configured capacity.
    """
    semaphore = _request_semaphore
    if semaphore is None:
        yield
        return

    # Reject instead of queueing: a saturated server answers immediately.
    if semaphore.locked():
        configured = os.getenv('MINERU_API_MAX_CONCURRENT_REQUESTS', '0')
        raise HTTPException(
            status_code=503,
            detail=f"Server is at maximum capacity: {configured}.",
        )

    async with semaphore:
        yield
|
|
|
|
|
|
# Runs of two or more path/dot characters, or any single slash/backslash.
_PATH_SEPARATOR_RE = re.compile(r'[/\\\.]{2,}|[/\\]')
# Anything that is not a (Unicode) word character, dot or hyphen.
_UNSAFE_CHAR_RE = re.compile(r'[^\w.-]', flags=re.UNICODE)


def sanitize_filename(filename: str) -> str:
    """Turn an arbitrary name into a single, non-hidden path component.

    Separator characters and multi-dot runs are deleted, every other character
    outside ``[\\w.-]`` becomes ``_``, and a leading dot is replaced with ``_``.
    An empty result falls back to ``'unnamed'``.
    """
    cleaned = _UNSAFE_CHAR_RE.sub('_', _PATH_SEPARATOR_RE.sub('', filename))
    if cleaned[:1] == '.':
        cleaned = f"_{cleaned[1:]}"
    return cleaned if cleaned else 'unnamed'
|
|
|
|
|
|
def cleanup_file(file_path: str) -> None:
    """Delete ``file_path`` on a best-effort basis.

    A missing file is silently ignored. Any other failure is logged as a warning
    and swallowed, so callers (e.g. response background tasks) never fail
    because of cleanup.
    """
    # EAFP instead of exists()+remove(): avoids the race where the file vanishes
    # between the check and the removal and a spurious warning gets logged.
    try:
        os.remove(file_path)
    except FileNotFoundError:
        # Already gone (or never created): nothing to clean up.
        pass
    except Exception as e:
        logger.warning(f"fail clean file {file_path}: {e}")
|
|
|
|
|
|
def encode_image(image_path: str) -> str:
    """Read the file at ``image_path`` and return its bytes as a base64 string."""
    with open(image_path, mode="rb") as handle:
        raw = handle.read()
    encoded = b64encode(raw)
    # b64encode yields ASCII bytes; decode them to a plain str.
    return encoded.decode()
|
|
|
|
|
|
def get_infer_result(file_suffix_identifier: str, pdf_name: str, parse_dir: str) -> Optional[str]:
    """Read ``<parse_dir>/<pdf_name><file_suffix_identifier>`` as UTF-8 text.

    Returns:
        The file contents, or None when that file does not exist.
    """
    target = os.path.join(parse_dir, pdf_name + file_suffix_identifier)
    if not os.path.exists(target):
        return None
    with open(target, encoding="utf-8") as handle:
        return handle.read()
|
|
|
|
|
|
# --- API router ---

api_router = APIRouter(prefix="/api")


def _get_parse_dir(unique_dir: str, pdf_name: str, backend: str, parse_method: str) -> str:
    """Return the directory in which parse results for ``pdf_name`` are looked up.

    The sub-directory depends on the backend family: ``<parse_method>`` for
    "pipeline*" backends, ``vlm`` for "vlm*" backends and ``hybrid_<parse_method>``
    for every other backend.
    """
    if backend.startswith("pipeline"):
        sub_dir = parse_method
    elif backend.startswith("vlm"):
        sub_dir = "vlm"
    else:
        sub_dir = f"hybrid_{parse_method}"
    return os.path.join(unique_dir, pdf_name, sub_dir)


def _requested_text_outputs(return_md: bool, return_middle_json: bool,
                            return_model_output: bool, return_content_list: bool) -> List[tuple]:
    """Return ``(response_key, file_suffix)`` pairs for the text outputs the client asked for."""
    # NOTE(review): the non-markdown suffixes assume aio_do_parse writes
    # "<name>_middle.json", "<name>_model.json" and "<name>_content_list.json"
    # next to "<name>.md" — confirm against mineru.cli.common.
    candidates = (
        (return_md, "md_content", ".md"),
        (return_middle_json, "middle_json", "_middle.json"),
        (return_model_output, "model_output", "_model.json"),
        (return_content_list, "content_list", "_content_list.json"),
    )
    return [(key, suffix) for wanted, key, suffix in candidates if wanted]


def _build_zip_response(unique_dir: str, pdf_file_names: List[str], backend: str, parse_method: str,
                        text_outputs: List[tuple], return_images: bool):
    """Package the requested outputs of every document into a temporary ZIP and return it.

    The ZIP is removed by a background task after the response is sent, or
    immediately if building the archive fails.
    """
    zip_fd, zip_path = tempfile.mkstemp(suffix=".zip", prefix="mineru_results_")
    os.close(zip_fd)
    try:
        with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
            for pdf_name in pdf_file_names:
                parse_dir = _get_parse_dir(unique_dir, pdf_name, backend, parse_method)
                if not os.path.exists(parse_dir):
                    continue
                # Archive entry names use the sanitized document name; the on-disk
                # lookup still uses the original name.
                safe_pdf_name = sanitize_filename(pdf_name)
                for _key, suffix in text_outputs:
                    path = os.path.join(parse_dir, f"{pdf_name}{suffix}")
                    if os.path.exists(path):
                        zf.write(path, arcname=os.path.join(safe_pdf_name, f"{safe_pdf_name}{suffix}"))
                if return_images:
                    images_dir = os.path.join(parse_dir, "images")
                    for img in glob.glob(os.path.join(glob.escape(images_dir), "*.jpg")):
                        zf.write(img, arcname=os.path.join(safe_pdf_name, "images", os.path.basename(img)))
    except Exception:
        # Do not leave a half-written archive behind in the temp directory.
        cleanup_file(zip_path)
        raise
    return FileResponse(path=zip_path, media_type="application/zip", filename="results.zip",
                        background=BackgroundTask(cleanup_file, zip_path))


def _collect_json_results(unique_dir: str, pdf_file_names: List[str], backend: str, parse_method: str,
                          text_outputs: List[tuple], return_images: bool) -> dict:
    """Return ``{document stem: {output key: value}}`` for the JSON response.

    Documents whose result directory is missing map to an empty dict; images are
    returned as ``data:image/jpeg;base64,...`` URIs keyed by file name.
    """
    results = {}
    for pdf_name in pdf_file_names:
        data = results[pdf_name] = {}
        parse_dir = _get_parse_dir(unique_dir, pdf_name, backend, parse_method)
        if not os.path.exists(parse_dir):
            continue
        for key, suffix in text_outputs:
            data[key] = get_infer_result(suffix, pdf_name, parse_dir)
        if return_images:
            img_dir = os.path.join(parse_dir, "images")
            data["images"] = {
                os.path.basename(p): f"data:image/jpeg;base64,{encode_image(p)}"
                for p in glob.glob(os.path.join(glob.escape(img_dir), "*.jpg"))
            }
    return results


@api_router.post(path="/file_parse", dependencies=[Depends(limit_concurrency)])
async def parse_pdf(
        files: List[UploadFile] = File(..., description="Upload pdf, image, or Word files for parsing"),
        output_dir: str = Form("./output", description="Output local directory"),
        lang_list: List[str] = Form(["ch"]),
        backend: str = Form("hybrid-auto-engine"),
        parse_method: str = Form("auto"),
        formula_enable: bool = Form(True),
        table_enable: bool = Form(True),
        server_url: Optional[str] = Form(None),
        return_md: bool = Form(True),
        return_middle_json: bool = Form(False),
        return_model_output: bool = Form(False),
        return_content_list: bool = Form(False),
        return_images: bool = Form(False),
        response_format_zip: bool = Form(False),
        start_page_id: int = Form(0),
        end_page_id: int = Form(99999),
):
    """Parse uploaded PDF / image / Word files and return the requested results.

    Each request gets its own ``<output_dir>/<uuid4>`` working directory. Every
    upload is written there, its type is checked with ``guess_suffix_by_path`` and
    it is loaded via ``read_fn``. The temporary upload copy is always removed
    afterwards. All documents are then parsed together by ``aio_do_parse``, with any
    server-level options taken from ``app.state.config``.

    Returns:
        A ZIP ``FileResponse`` when ``response_format_zip`` is true. Otherwise a
        JSON body ``{"backend", "version", "results"}``, where ``results`` maps each
        document stem to the requested outputs. A 400 JSON error is returned for
        unsupported or unreadable uploads, and a 500 JSON error for any other failure.
    """
    # Server-wide options are stored on the module-level app by main(); `app` is
    # created further down in this module and resolved at call time.
    config = getattr(app.state, "config", {})

    try:
        # NOTE(security): output_dir comes straight from the client, so a caller can
        # choose where on the server's filesystem results are written. Consider
        # restricting it to a configured base directory.
        unique_dir = os.path.join(output_dir, str(uuid.uuid4()))
        os.makedirs(unique_dir, exist_ok=True)

        pdf_file_names = []
        pdf_bytes_list = []

        for file in files:
            content = await file.read()
            file_path = Path(file.filename)
            # Only the final path component is used, so directory parts in the
            # client-supplied name are dropped.
            temp_path = Path(unique_dir) / file_path.name
            try:
                with open(temp_path, "wb") as f:
                    f.write(content)

                file_suffix = guess_suffix_by_path(temp_path)
                if file_suffix not in pdf_suffixes + image_suffixes + word_suffixes:
                    return JSONResponse(status_code=400, content={"error": f"Unsupported file type: {file_suffix}"})
                try:
                    pdf_bytes = read_fn(temp_path)
                except Exception as e:
                    return JSONResponse(status_code=400, content={"error": f"Failed to load file: {str(e)}"})
                pdf_bytes_list.append(pdf_bytes)
                pdf_file_names.append(file_path.stem)
            finally:
                # The on-disk copy is only needed for type detection and loading;
                # remove it on success and on the early 400 returns above.
                cleanup_file(str(temp_path))

        # One language per document is expected; otherwise reuse the first one
        # (or "ch" when none was sent) for every document.
        actual_lang_list = lang_list
        if len(actual_lang_list) != len(pdf_file_names):
            actual_lang_list = [actual_lang_list[0] if actual_lang_list else "ch"] * len(pdf_file_names)

        await aio_do_parse(
            output_dir=unique_dir, pdf_file_names=pdf_file_names, pdf_bytes_list=pdf_bytes_list,
            p_lang_list=actual_lang_list, backend=backend, parse_method=parse_method,
            formula_enable=formula_enable, table_enable=table_enable, server_url=server_url,
            f_draw_layout_bbox=False, f_draw_span_bbox=False, f_dump_md=return_md,
            f_dump_middle_json=return_middle_json, f_dump_model_output=return_model_output,
            f_dump_orig_pdf=False, f_dump_content_list=return_content_list,
            start_page_id=start_page_id, end_page_id=end_page_id, **config
        )

        text_outputs = _requested_text_outputs(return_md, return_middle_json,
                                               return_model_output, return_content_list)

        if response_format_zip:
            return _build_zip_response(unique_dir, pdf_file_names, backend, parse_method,
                                       text_outputs, return_images)

        result_dict = _collect_json_results(unique_dir, pdf_file_names, backend, parse_method,
                                            text_outputs, return_images)
        return JSONResponse(status_code=200,
                            content={"backend": backend, "version": __version__, "results": result_dict})
    except Exception as e:
        logger.exception(e)
        return JSONResponse(status_code=500, content={"error": f"Internal Error: {str(e)}"})
|
|
|
|
|
|
# --- Core FastAPI application ---

def create_app():
    """Build and configure the FastAPI application.

    Environment variables read:
        MINERU_API_ENABLE_FASTAPI_DOCS: "1"/"true"/"yes" (default "1") exposes the
            /openapi.json, /docs and /redoc endpoints; any other value hides them.
        MINERU_API_MAX_CONCURRENT_REQUESTS: a positive integer creates the
            module-level request semaphore used by ``limit_concurrency``. A
            non-integer value is logged and ignored.

    Returns:
        The configured FastAPI instance, with the /api router, a /health route
        and, when the bundled web UI directory exists, static files mounted at "/".
    """
    global _request_semaphore

    enable_docs = os.getenv("MINERU_API_ENABLE_FASTAPI_DOCS", "1").lower() in ("1", "true", "yes")
    app = FastAPI(
        openapi_url="/openapi.json" if enable_docs else None,
        docs_url="/docs" if enable_docs else None,
        redoc_url="/redoc" if enable_docs else None,
    )

    raw_limit = os.getenv("MINERU_API_MAX_CONCURRENT_REQUESTS", "0")
    try:
        max_concurrent = int(raw_limit)
    except ValueError:
        # Previously swallowed silently by a bare except; surface the misconfiguration.
        logger.warning(f"Ignoring invalid MINERU_API_MAX_CONCURRENT_REQUESTS={raw_limit!r}")
        max_concurrent = 0
    if max_concurrent > 0:
        _request_semaphore = asyncio.Semaphore(max_concurrent)
        logger.info(f"Concurrency limited to {max_concurrent}")

    # NOTE(security): allow_origins=["*"] combined with allow_credentials=True lets
    # any website send credentialed cross-origin requests (Starlette echoes the
    # request Origin in that case) — confirm this is intended.
    app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"],
                       allow_headers=["*"])
    app.add_middleware(GZipMiddleware, minimum_size=1000)

    # 1. Mount the API router first (handles every /api/* request).
    app.include_router(api_router)

    # 2. Other root-level routes (optional).
    @app.get("/health")
    async def health():
        return {"status": "ok"}

    # 3. Mount static file serving last: a mount at "/" matches every path, so it
    #    must come after the routes above. It serves the remaining requests
    #    (/, /index.html, /assets/*, ...).
    static_dir = Path(__file__).parent / "static" / "web"
    if static_dir.exists():
        logger.info(f"Mounting static files from {static_dir}")
        app.mount("/", StaticFiles(directory=static_dir, html=True), name="static")
    else:
        logger.warning("Static directory not found, web UI will not be available.")

    return app


# Module-level application: parse_pdf reads app.state.config from it, and uvicorn
# can load it as "mineru.cli.fast_api:app".
app = create_app()
|
|
|
|
|
|
@click.command(context_settings=dict(ignore_unknown_options=True, allow_extra_args=True))
@click.pass_context
@click.option('--host', default='127.0.0.1')
@click.option('--port', default=8000, type=int)
@click.option('--reload', is_flag=True)
def main(ctx, host, port, reload, **kwargs):
    """Run the MinerU API server with uvicorn.

    Extra, unrecognised command-line options are parsed by ``arg_parse`` and stored
    on ``app.state.config``, from which ``parse_pdf`` forwards them to
    ``aio_do_parse``. A ``mineru_api_max_concurrent_requests`` option, when given,
    overrides both the MINERU_API_MAX_CONCURRENT_REQUESTS environment variable and
    the active request limit. When it is absent, the environment variable is left
    as-is.
    """
    global _request_semaphore

    kwargs.update(arg_parse(ctx))
    app.state.config = kwargs

    # `app` (and its semaphore) was built by create_app() when this module was
    # imported, before the CLI was parsed, so a limit given on the command line
    # has to be applied here explicitly.
    cli_limit = kwargs.get("mineru_api_max_concurrent_requests")
    if cli_limit is not None:
        limit_str = str(cli_limit or "0")
        # Exported so that limit_concurrency's error message and a reloaded worker
        # (which re-runs create_app) see the same value.
        os.environ["MINERU_API_MAX_CONCURRENT_REQUESTS"] = limit_str
        try:
            limit = int(limit_str)
        except ValueError:
            logger.warning(f"Ignoring invalid mineru_api_max_concurrent_requests={limit_str!r}")
            limit = 0
        if limit > 0:
            _request_semaphore = asyncio.Semaphore(limit)
            logger.info(f"Concurrency limited to {limit}")
        else:
            _request_semaphore = None

    if reload:
        # Reloading requires an import string; the reloaded worker re-imports this
        # module in a fresh process, so app.state.config set above is not carried
        # over there (parse_pdf then falls back to an empty config).
        uvicorn.run("mineru.cli.fast_api:app", host=host, port=port, reload=True)
    else:
        # Serve this exact app object, so the config and semaphore configured above
        # are used even when the module was started as a script (__main__).
        uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    main()