484 lines
19 KiB
Python
484 lines
19 KiB
Python
import sys
|
||
import uuid
|
||
import os
|
||
import re
|
||
import tempfile
|
||
import asyncio
|
||
import threading
|
||
import uvicorn
|
||
import click
|
||
import zipfile
|
||
from pathlib import Path
|
||
import glob
|
||
from fastapi import Depends, FastAPI, HTTPException, UploadFile, File, Form, APIRouter, Header
|
||
from fastapi.middleware.gzip import GZipMiddleware
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
from fastapi.responses import JSONResponse, FileResponse
|
||
from fastapi.staticfiles import StaticFiles
|
||
from starlette.background import BackgroundTask
|
||
from typing import List, Optional
|
||
from loguru import logger
|
||
from base64 import b64encode
|
||
|
||
# MinerU internal imports
|
||
from mineru.cli.common import aio_do_parse, read_fn, pdf_suffixes, image_suffixes, word_suffixes
|
||
from mineru.utils.cli_parser import arg_parse
|
||
from mineru.utils.guess_suffix_or_lang import guess_suffix_by_path
|
||
from mineru.version import __version__
|
||
|
||
# --- Logging configuration ---
|
||
# Verbosity is read from MINERU_LOG_LEVEL (INFO when unset). All previously
# registered loguru handlers are removed so stderr is the only sink.
log_level = os.getenv("MINERU_LOG_LEVEL", "INFO").upper()
logger.remove()
logger.add(sys.stderr, level=log_level)
|
||
|
||
# --- Global variables and helper classes ---
|
||
# Semaphore consulted by limit_concurrency(); stays None (no limit) unless
# create_app() installs one from MINERU_API_MAX_CONCURRENT_REQUESTS.
_request_semaphore: Optional[asyncio.Semaphore] = None
|
||
|
||
# --- Task progress tracking ---
|
||
# Progress records keyed by task id. parse_pdf() creates each record with the
# keys "progress", "stage", "status", "error" and "file_names".
_task_progress: dict = {}
|
||
|
||
|
||
def _update_task_progress(task_id: Optional[str], progress: int, stage: str):
    """Record a new progress value and stage label for a known task.

    Calls with a falsy ``task_id`` or an id that has no entry in
    ``_task_progress`` are silently ignored. The stored progress is capped
    at 100.
    """
    if not task_id or task_id not in _task_progress:
        return
    record = _task_progress[task_id]
    record["progress"] = min(progress, 100)
    record["stage"] = stage
|
||
|
||
|
||
class ProgressTracker:
    """Minimal holder for a progress value and a status label."""

    def __init__(self):
        self.progress, self.status = 0, "初始化"

    def update(self, progress: float, status: str):
        """Overwrite both the progress value and the status text."""
        self.status = status
        self.progress = progress

    def get_progress(self):
        """Return the current state as a dict with "progress" and "status"."""
        return dict(progress=self.progress, status=self.status)
|
||
|
||
|
||
# --- Currently active task id (used for log capture) ---
|
||
# Task id that _log_progress_sink() attributes log lines to. parse_pdf() sets
# it while a parse runs and resets it to None afterwards.
_current_task_id: Optional[str] = None
|
||
|
||
|
||
# Ordered rules mapping log fragments to (progress, stage). A rule fires when
# every fragment in its first element occurs in the log text; the first rule
# that fires wins. A progress of None marks the one rule that is conditional
# on the task's current progress (see _log_progress_sink).
_LOG_PROGRESS_RULES = (
    # VLM engine loading phase (5% ~ 40%)
    (("Automatically detected platform",), 7, "检测计算平台"),
    (("Resolved architecture",), 10, "解析模型架构"),
    (("Starting to load model",), 12, "开始加载模型权重"),
    (("Loading weights took",), 18, "模型权重加载完成"),
    (("Model loading took",), 22, "模型加载完成,初始化引擎"),
    (("Dynamo bytecode transform",), 27, "编译优化计算图"),
    (("torch.compile takes",), 30, "计算图编译完成"),
    (("Available KV cache memory",), 32, "分配显存缓存"),
    (("Graph capturing finished",), 36, "CUDA图捕获完成"),
    (("init engine", "took"), 38, "VLM引擎初始化完成"),
    (("get vllm", "predictor cost"), 40, "VLM预测器就绪"),
    (("hybrid batch ratio",), 42, "开始文档分析"),
    # VLM inference phase (40% ~ 65%) is handled by progress_callback.
    # Pipeline model loading and inference phase (65% ~ 97%)
    (("Target directory already exists",), None, "加载Pipeline模型"),
    (("MFD Predict:",), 78, "数学公式检测"),
    (("MFR Predict:",), 85, "数学公式识别"),
    (("OCR-det:",), 90, "文字区域检测"),
    (("OCR-rec Predict:",), 95, "文字识别"),
    (("local output dir is",), 97, "保存输出结果"),
)


def _log_progress_sink(message):
    """loguru sink that derives task progress from the text of log messages.

    Does nothing unless ``_current_task_id`` names a task present in
    ``_task_progress``. Messages that cannot be converted with ``str()`` are
    ignored.
    """
    task_id = _current_task_id
    if not task_id or task_id not in _task_progress:
        return
    try:
        msg = str(message)
    except Exception:
        return

    for needles, progress, stage in _LOG_PROGRESS_RULES:
        if not all(needle in msg for needle in needles):
            continue
        if progress is None:
            # Only move to at least 66% while the task is still below 75%.
            cur = _task_progress.get(task_id, {}).get("progress", 0)
            if cur < 75:
                _update_task_progress(task_id, max(cur, 66), stage)
        else:
            _update_task_progress(task_id, progress, stage)
        return
|
||
|
||
|
||
class _StderrProgressCapture:
    """Tee ``sys.stderr`` while a task runs and turn tqdm bars into task progress.

    The previous implementation started a thread that called ``read(1)`` on
    ``sys.stderr``. Data written to a stream cannot be read back from it, and
    a write-only stream raises on ``read``, so that loop ended immediately and
    no progress was ever parsed. This version installs itself as
    ``sys.stderr`` between ``start()`` and ``stop()``: every write is passed
    through to the original stream unchanged and also scanned for tqdm
    progress lines.

    Only writers that look up ``sys.stderr`` after ``start()`` are observed;
    objects that kept a reference to the original stream bypass the tee.
    """

    # tqdm bar shape: "<name>: <percent>%|...| <current>/<total>"
    _PATTERNS = [
        (re.compile(r'Two Step Extraction:\s*(\d+)%.*?(\d+)/(\d+)'), 'extract'),
        (re.compile(r'MFD Predict:\s*(\d+)%.*?(\d+)/(\d+)'), 'mfd'),
        (re.compile(r'MFR Predict:\s*(\d+)%.*?(\d+)/(\d+)'), 'mfr'),
        (re.compile(r'OCR-det:\s*(\d+)%.*?(\d+)/(\d+)'), 'ocr_det'),
        (re.compile(r'OCR-rec Predict:\s*(\d+)%.*?(\d+)/(\d+)'), 'ocr_rec'),
        (re.compile(r'Loading safetensors.*?:\s*(\d+)%'), 'load_model'),
        (re.compile(r'Capturing CUDA graphs.*?:\s*(\d+)%'), 'cuda_graph'),
    ]

    # Overall-progress window [start%, end%] that each stage is mapped into.
    _RANGES = {
        'load_model': (12, 18),
        'cuda_graph': (33, 37),
        'extract': (42, 65),
        'mfd': (75, 80),
        'mfr': (80, 87),
        'ocr_det': (87, 92),
        'ocr_rec': (92, 96),
    }

    _STAGE_LABELS = {
        'load_model': '加载模型权重',
        'cuda_graph': '捕获CUDA计算图',
        'extract': 'VLM文档分析',
        'mfd': '数学公式检测',
        'mfr': '数学公式识别',
        'ocr_det': '文字区域检测',
        'ocr_rec': '文字识别',
    }

    # tqdm redraws a bar with "\r"; "\n" ends a line. Either one completes a line.
    _LINE_BREAK = re.compile(r'[\r\n]')

    # Upper bound on buffered text that has not yet been terminated by a line break.
    _MAX_PENDING = 4096

    def __init__(self, task_id: str):
        self.task_id = task_id
        self._active = False
        # Kept for backward compatibility; no reader thread is used any more.
        self._thread: Optional[threading.Thread] = None
        self._orig_stderr = None
        self._pending = ""
        self._lock = threading.Lock()

    def start(self):
        """Install this object as ``sys.stderr`` and begin parsing writes."""
        if self._active:
            return
        self._pending = ""
        self._orig_stderr = sys.stderr
        self._active = True
        sys.stderr = self

    def stop(self):
        """Stop parsing and restore the stream that was active before ``start()``."""
        was_active = self._active
        self._active = False
        self._thread = None
        if not was_active:
            return
        with self._lock:
            self._pending = ""
        if sys.stderr is self:
            # Skip over captures that were stopped while we were stacked on top
            # of them, so no inactive tee is left installed.
            target = self._orig_stderr
            while isinstance(target, _StderrProgressCapture) and not target._active:
                target = target._orig_stderr
            sys.stderr = target
        # If another capture is stacked on top of us we stay in its chain as a
        # plain pass-through; it unwinds past us when it stops.

    def write(self, data):
        """Forward ``data`` to the original stream, then scan it for progress."""
        result = self._orig_stderr.write(data)
        if self._active and isinstance(data, str):
            try:
                self._feed(data)
            except Exception:
                # Progress parsing is best-effort and must never break stderr output.
                pass
        return result

    def flush(self):
        """Flush the original stream."""
        self._orig_stderr.flush()

    def __getattr__(self, name):
        # Reached only for attributes this object lacks (isatty, fileno,
        # encoding, ...): delegate them to the wrapped stream.
        orig = self.__dict__.get("_orig_stderr")
        if orig is None:
            raise AttributeError(name)
        return getattr(orig, name)

    def _feed(self, data: str):
        """Buffer ``data`` and parse every line completed by "\\r" or "\\n"."""
        with self._lock:
            *lines, rest = self._LINE_BREAK.split(self._pending + data)
            self._pending = rest[-self._MAX_PENDING:]
        for raw in lines:
            line = raw.strip()
            if line:
                self._parse_line(line)

    def _parse_line(self, line: str):
        """Update task progress from the first tqdm pattern that matches ``line``."""
        for pattern, stage in self._PATTERNS:
            m = pattern.search(line)
            if not m:
                continue
            pct = int(m.group(1))
            lo, hi = self._RANGES.get(stage, (0, 100))
            # Scale the bar's own percentage into the stage's overall window.
            mapped = lo + int((hi - lo) * pct / 100)
            label = self._STAGE_LABELS.get(stage, stage)
            if stage == 'extract' and len(m.groups()) >= 3:
                cur, total = m.group(2), m.group(3)
                label = f"VLM文档分析 ({cur}/{total}页)"
            elif stage in ('mfd', 'mfr', 'ocr_det', 'ocr_rec') and len(m.groups()) >= 3:
                cur, total = m.group(2), m.group(3)
                label = f"{label} ({cur}/{total})"
            _update_task_progress(self.task_id, mapped, label)
            break
|
||
|
||
|
||
async def limit_concurrency():
    """FastAPI dependency that rejects requests once the concurrency cap is hit.

    With no semaphore configured the request proceeds unthrottled. Otherwise a
    fully-taken semaphore yields an immediate HTTP 503 instead of queueing,
    and an available one is held for the duration of the request.
    """
    semaphore = _request_semaphore
    if semaphore is None:
        yield
        return

    if semaphore.locked():
        raise HTTPException(
            status_code=503,
            detail=f"Server is at maximum capacity: {os.getenv('MINERU_API_MAX_CONCURRENT_REQUESTS', '0')}."
        )

    async with semaphore:
        yield
|
||
|
||
|
||
def sanitize_filename(filename: str) -> str:
    """Reduce ``filename`` to a name that is safe to use as a path component.

    Runs of two or more of ``/``, ``\\`` and ``.`` and any single slash or
    backslash are removed. Every remaining character that is not a word
    character, ``.`` or ``-`` becomes ``_``. A leading ``.`` is replaced by
    ``_``. An empty result becomes ``'unnamed'``.
    """
    without_separators = re.sub(r'[/\\\.]{2,}|[/\\]', '', filename)
    cleaned = re.sub(r'[^\w.-]', '_', without_separators, flags=re.UNICODE)
    if cleaned[:1] == '.':
        cleaned = '_' + cleaned[1:]
    if not cleaned:
        return 'unnamed'
    return cleaned
|
||
|
||
|
||
def cleanup_file(file_path: str) -> None:
    """Delete ``file_path`` if it exists; log a warning instead of raising."""
    try:
        if not os.path.exists(file_path):
            return
        os.remove(file_path)
    except Exception as exc:
        logger.warning(f"fail clean file {file_path}: {exc}")
|
||
|
||
|
||
def encode_image(image_path: str) -> str:
    """Return the contents of the file at ``image_path`` as a base64 string."""
    with open(image_path, "rb") as image_file:
        raw = image_file.read()
    return b64encode(raw).decode()
|
||
|
||
|
||
def get_infer_result(file_suffix_identifier: str, pdf_name: str, parse_dir: str) -> Optional[str]:
    """Read ``<parse_dir>/<pdf_name><file_suffix_identifier>`` as UTF-8 text.

    Returns the file's text, or ``None`` when the file does not exist.
    """
    candidate = os.path.join(parse_dir, pdf_name + file_suffix_identifier)
    if not os.path.exists(candidate):
        return None
    with open(candidate, "r", encoding="utf-8") as handle:
        return handle.read()
|
||
|
||
|
||
# --- Create the API router group ---
|
||
# Router carrying this module's endpoints; every route on it is served under /api.
api_router = APIRouter(prefix="/api")
|
||
|
||
|
||
@api_router.get("/parse_progress/{task_id}")
async def get_parse_progress(task_id: str):
    """Return the live progress record of a parse task (HTTP 404 if the id is unknown)."""
    try:
        return _task_progress[task_id]
    except KeyError:
        raise HTTPException(status_code=404, detail="Task not found") from None
|
||
|
||
|
||
def _resolve_parse_dir(base_dir: str, pdf_name: str, backend: str, parse_method: str) -> str:
    """Return the directory in which results for one document are looked up."""
    if backend.startswith("pipeline"):
        leaf = parse_method
    elif backend.startswith("vlm"):
        leaf = "vlm"
    else:
        leaf = f"hybrid_{parse_method}"
    return os.path.join(base_dir, pdf_name, leaf)


def _mark_task_failed(task_id: str, error: str) -> None:
    """Flag a tracked task as failed so progress pollers stop seeing "processing"."""
    entry = _task_progress.get(task_id)
    if entry is not None:
        entry["status"] = "failed"
        entry["error"] = error


@api_router.post(path="/file_parse", dependencies=[Depends(limit_concurrency)])
async def parse_pdf(
        files: List[UploadFile] = File(..., description="Upload pdf, image, or Word files for parsing"),
        output_dir: str = Form("./output", description="Output local directory"),
        lang_list: List[str] = Form(["ch"]),
        backend: str = Form("hybrid-auto-engine"),
        parse_method: str = Form("auto"),
        formula_enable: bool = Form(True),
        table_enable: bool = Form(True),
        server_url: Optional[str] = Form(None),
        return_md: bool = Form(True),
        return_middle_json: bool = Form(False),
        return_model_output: bool = Form(False),
        return_content_list: bool = Form(False),
        return_images: bool = Form(False),
        response_format_zip: bool = Form(False),
        start_page_id: int = Form(0),
        end_page_id: int = Form(99999),
        x_task_id: Optional[str] = Header(None),
):
    """Parse the uploaded files and return the results as JSON or as a zip archive.

    Progress is tracked in ``_task_progress`` under the ``X-Task-Id`` header
    value (a random UUID when the header is absent) and can be polled through
    ``/api/parse_progress/{task_id}``.

    Responses:
        200: JSON with "backend", "version" and "results", or a zip file when
            ``response_format_zip`` is true.
        400: a file has an unsupported type or could not be loaded.
        500: any other error while parsing or building the response.

    On every 400/500 path the task record is marked "failed".
    """
    global _current_task_id

    # Configuration stored on the app instance by main() (the FastAPI instance is created further down).
    config = getattr(app.state, "config", {})

    # Initialise progress tracking.
    task_id = x_task_id or str(uuid.uuid4())
    file_names_str = ", ".join(f.filename or "unknown" for f in files)
    _task_progress[task_id] = {
        "progress": 0, "stage": "准备中", "status": "processing",
        "error": None, "file_names": file_names_str,
    }

    try:
        unique_dir = os.path.join(output_dir, str(uuid.uuid4()))
        os.makedirs(unique_dir, exist_ok=True)

        pdf_file_names = []
        pdf_bytes_list = []
        supported_suffixes = pdf_suffixes + image_suffixes + word_suffixes

        for file in files:
            content = await file.read()
            # An upload may arrive without a filename; fall back to the same
            # placeholder used for the progress record.
            file_path = Path(file.filename or "unknown")
            temp_path = Path(unique_dir) / file_path.name
            with open(temp_path, "wb") as f:
                f.write(content)

            file_suffix = guess_suffix_by_path(temp_path)
            if file_suffix not in supported_suffixes:
                message = f"Unsupported file type: {file_suffix}"
                _mark_task_failed(task_id, message)
                return JSONResponse(status_code=400, content={"error": message})

            try:
                pdf_bytes = read_fn(temp_path)
                pdf_bytes_list.append(pdf_bytes)
                pdf_file_names.append(file_path.stem)
                os.remove(temp_path)
            except Exception as e:
                message = f"Failed to load file: {str(e)}"
                _mark_task_failed(task_id, message)
                return JSONResponse(status_code=400, content={"error": message})

        # One language per document; otherwise repeat the first one (or "ch").
        actual_lang_list = lang_list
        if len(actual_lang_list) != len(pdf_file_names):
            actual_lang_list = [actual_lang_list[0] if actual_lang_list else "ch"] * len(pdf_file_names)

        # Progress callback: maps the parser's own percentage into the 40-65% window.
        def progress_callback(pct, msg):
            msg_str = str(msg)
            if "处理文件" in msg_str:
                _update_task_progress(task_id, 40, f"VLM文档分析: {msg_str}")
            elif "完成文件" in msg_str:
                _update_task_progress(task_id, 65, f"VLM分析完成: {msg_str}")
            else:
                _update_task_progress(task_id, int(40 + pct * 0.25), msg_str)

        _update_task_progress(task_id, 5, "开始解析文档")

        sink_id = None
        stderr_capture = None
        try:
            # Bind this task to the log sink. NOTE: _current_task_id is a single
            # module-level value, so with concurrent requests log-derived progress
            # is attributed to whichever task set it last.
            _current_task_id = task_id
            sink_id = logger.add(_log_progress_sink, level="DEBUG")

            # Start stderr capture (parses tqdm progress-bar output).
            stderr_capture = _StderrProgressCapture(task_id)
            stderr_capture.start()

            await aio_do_parse(
                output_dir=unique_dir, pdf_file_names=pdf_file_names, pdf_bytes_list=pdf_bytes_list,
                p_lang_list=actual_lang_list, backend=backend, parse_method=parse_method,
                formula_enable=formula_enable, table_enable=table_enable, server_url=server_url,
                f_draw_layout_bbox=False, f_draw_span_bbox=False, f_dump_md=return_md,
                f_dump_middle_json=return_middle_json, f_dump_model_output=return_model_output,
                f_dump_orig_pdf=False, f_dump_content_list=return_content_list,
                start_page_id=start_page_id, end_page_id=end_page_id,
                progress_callback=progress_callback, **config
            )
        finally:
            # Always detach the capture hooks, whether parsing succeeded or failed.
            if stderr_capture is not None:
                stderr_capture.stop()
            if sink_id is not None:
                logger.remove(sink_id)
            # Do not clear an id that another request has installed meanwhile.
            if _current_task_id == task_id:
                _current_task_id = None

        _update_task_progress(task_id, 97, "生成结果文件")
        _update_task_progress(task_id, 100, "转换完成")
        _task_progress[task_id]["status"] = "completed"

        if response_format_zip:
            zip_fd, zip_path = tempfile.mkstemp(suffix=".zip", prefix="mineru_results_")
            os.close(zip_fd)
            try:
                with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
                    for pdf_name in pdf_file_names:
                        safe_pdf_name = sanitize_filename(pdf_name)
                        p_dir = _resolve_parse_dir(unique_dir, pdf_name, backend, parse_method)
                        if not os.path.exists(p_dir):
                            continue
                        if return_md:
                            path = os.path.join(p_dir, f"{pdf_name}.md")
                            if os.path.exists(path):
                                zf.write(path, arcname=os.path.join(safe_pdf_name, f"{safe_pdf_name}.md"))
                        if return_images:
                            images_dir = os.path.join(p_dir, "images")
                            for img in glob.glob(os.path.join(glob.escape(images_dir), "*.jpg")):
                                zf.write(img, arcname=os.path.join(safe_pdf_name, "images", os.path.basename(img)))
            except Exception:
                # The response that would have deleted the temp zip is never sent.
                cleanup_file(zip_path)
                raise

            # The zip is deleted by the background task once the response has been sent.
            return FileResponse(path=zip_path, media_type="application/zip", filename="results.zip",
                                background=BackgroundTask(cleanup_file, zip_path))

        result_dict = {}
        for pdf_name in pdf_file_names:
            result_dict[pdf_name] = {}
            data = result_dict[pdf_name]
            p_dir = _resolve_parse_dir(unique_dir, pdf_name, backend, parse_method)
            if not os.path.exists(p_dir):
                continue
            if return_md:
                data["md_content"] = get_infer_result(".md", pdf_name, p_dir)
            if return_images:
                img_dir = os.path.join(p_dir, "images")
                data["images"] = {
                    os.path.basename(p): f"data:image/jpeg;base64,{encode_image(p)}"
                    for p in glob.glob(os.path.join(glob.escape(img_dir), "*.jpg"))
                }

        return JSONResponse(status_code=200,
                            content={"backend": backend, "version": __version__, "results": result_dict})
    except Exception as e:
        logger.exception(e)
        _mark_task_failed(task_id, str(e))
        return JSONResponse(status_code=500, content={"error": f"Internal Error: {str(e)}"})
|
||
|
||
|
||
# --- FastAPI core application ---
|
||
def create_app():
    """Build and configure the FastAPI application.

    Environment variables read:
        MINERU_API_ENABLE_FASTAPI_DOCS: "1"/"true"/"yes" (default "1") exposes
            /openapi.json, /docs and /redoc; any other value disables them.
        MINERU_API_MAX_CONCURRENT_REQUESTS: a positive integer installs a
            module-level semaphore of that size; "0" (default), a negative or
            an invalid value leaves concurrency unlimited.

    Returns:
        The configured ``FastAPI`` instance.
    """
    global _request_semaphore

    enable_docs = str(os.getenv("MINERU_API_ENABLE_FASTAPI_DOCS", "1")).lower() in ("1", "true", "yes")
    app = FastAPI(
        openapi_url="/openapi.json" if enable_docs else None,
        docs_url="/docs" if enable_docs else None,
        redoc_url="/redoc" if enable_docs else None,
    )

    raw_limit = os.getenv("MINERU_API_MAX_CONCURRENT_REQUESTS", "0")
    try:
        mcr = int(raw_limit)
    except ValueError:
        # A malformed value still means "no limit", but say so instead of hiding it.
        logger.warning(
            f"Ignoring invalid MINERU_API_MAX_CONCURRENT_REQUESTS={raw_limit!r}; concurrency is not limited"
        )
        mcr = 0
    if mcr > 0:
        _request_semaphore = asyncio.Semaphore(mcr)
        logger.info(f"Concurrency limited to {mcr}")

    app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"],
                       allow_headers=["*"])
    app.add_middleware(GZipMiddleware, minimum_size=1000)

    # 1. Mount the API router first (handles every /api/* request).
    app.include_router(api_router)

    # 2. Health-check endpoint.
    @app.get("/health")
    async def health():
        return {"status": "ok"}

    # 3. Mount static files last (serves everything else: /, /index.html, /assets/*).
    static_dir = Path(__file__).parent / "static" / "web"
    if static_dir.exists():
        logger.info(f"Mounting static files from {static_dir}")
        app.mount("/", StaticFiles(directory=static_dir, html=True), name="static")
    else:
        logger.warning("Static directory not found, web UI will not be available.")

    return app
|
||
|
||
|
||
# Module-level application instance; main() hands uvicorn the import path "mineru.cli.fast_api:app".
app = create_app()
|
||
|
||
|
||
@click.command(context_settings=dict(ignore_unknown_options=True, allow_extra_args=True))
@click.pass_context
@click.option('--host', default='127.0.0.1')
@click.option('--port', default=8000, type=int)
@click.option('--reload', is_flag=True)
def main(ctx, host, port, reload, **kwargs):
    """Command-line entry point: store the parsed options and start uvicorn.

    Extra command-line arguments are parsed with ``arg_parse`` and merged into
    ``kwargs``, which is stored as ``app.state.config``. The configured
    concurrency limit is exported through MINERU_API_MAX_CONCURRENT_REQUESTS
    before the server starts.
    """
    extra_options = arg_parse(ctx)
    kwargs.update(extra_options)
    app.state.config = kwargs

    configured_limit = kwargs.get("mineru_api_max_concurrent_requests", "0") or "0"
    os.environ["MINERU_API_MAX_CONCURRENT_REQUESTS"] = str(configured_limit)

    uvicorn.run("mineru.cli.fast_api:app", host=host, port=port, reload=reload)
|
||
|
||
|
||
# Run the CLI only when this file is executed as a script.
if __name__ == "__main__":
    main()