UnisMindMap/mineru/model/table/rec/RapidTable.py

155 lines
5.8 KiB
Python

import html
import os
import time
from pathlib import Path
from typing import List
import cv2
import numpy as np
from loguru import logger
from rapid_table import ModelType, RapidTable, RapidTableInput
from rapid_table.utils import RapidTableOutput
from tqdm import tqdm
from mineru.model.ocr.pytorch_paddle import PytorchPaddleOCR
from mineru.utils.enum_class import ModelPath
from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
def escape_html(input_string):
    """Return ``input_string`` with HTML special characters replaced by entities.

    Delegates to :func:`html.escape` with quote escaping enabled, so ``&``,
    ``<``, ``>``, ``"`` and ``'`` are all converted.
    """
    escaped = html.escape(input_string, quote=True)
    return escaped
class CustomRapidTable(RapidTable):
    """``RapidTable`` subclass whose ``__call__`` always matches cells with supplied OCR.

    ``RapidTableModel`` in this module constructs it with ``use_ocr=False`` and
    passes ``ocr_results`` explicitly. The overridden ``__call__`` runs structure
    recognition, logic-point decoding and cell/OCR matching for every batch and
    reports progress with a tqdm bar.
    """

    def __init__(self, cfg: RapidTableInput):
        import logging

        # Disable all stdlib-logging records at INFO level and below, process-wide,
        # before the parent initialises -- presumably to silence rapid_table's
        # INFO-level startup messages. Note this affects every logger in the process.
        logging.disable(logging.INFO)
        super().__init__(cfg)

    def __call__(self, img_contents, ocr_results=None, batch_size=1):
        """Recognise table structure for one or more images and build their HTML.

        Args:
            img_contents: A single image or a list of images, in any form accepted
                by ``self._load_imgs``.
            ocr_results: OCR results for the images, passed (with the batch's
                start/end indices) to ``self.get_ocr_results``.
            batch_size: Number of images handed to the structure model per step.

        Returns:
            A ``RapidTableOutput`` whose ``pred_htmls``, ``cell_bboxes`` and
            ``logic_points`` are extended batch by batch with the matcher output,
            the structure model's cell boxes and the decoded logic points.
            ``elapse`` is the average wall-clock seconds per image (0.0 when
            ``img_contents`` is an empty list).

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            # range() would raise for 0 and silently do nothing for negatives.
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        if not isinstance(img_contents, list):
            img_contents = [img_contents]

        start_time = time.perf_counter()
        results = RapidTableOutput()
        total_nums = len(img_contents)
        if total_nums == 0:
            # Nothing to predict; return early instead of dividing by zero below.
            results.elapse = 0.0
            return results

        with tqdm(total=total_nums, desc="Table-wireless Predict") as pbar:
            for start_i in range(0, total_nums, batch_size):
                end_i = min(total_nums, start_i + batch_size)
                imgs = self._load_imgs(img_contents[start_i:end_i])
                pred_structures, cell_bboxes = self.table_structure(imgs)
                logic_points = self.table_matcher.decode_logic_points(pred_structures)
                dt_boxes, rec_res = self.get_ocr_results(imgs, start_i, end_i, ocr_results)
                pred_htmls = self.table_matcher(
                    pred_structures, cell_bboxes, dt_boxes, rec_res
                )
                results.pred_htmls.extend(pred_htmls)
                # Also collect cell boxes and logic points: RapidTableModel.predict
                # returns these fields, so they must be populated here.
                results.cell_bboxes.extend(cell_bboxes)
                results.logic_points.extend(logic_points)
                pbar.update(end_i - start_i)

        results.elapse = (time.perf_counter() - start_time) / total_nums
        return results
class RapidTableModel:
    """Wireless-table recogniser built on SLANet+ through ``CustomRapidTable``.

    OCR is either supplied by the caller or, in :meth:`predict`, produced with
    the ``ocr_engine`` given at construction time.
    """

    def __init__(self, ocr_engine):
        # Resolve the SLANet+ model directory via the project's download helper.
        slanet_plus_model_path = os.path.join(
            auto_download_and_get_model_root_path(ModelPath.slanet_plus),
            ModelPath.slanet_plus,
        )
        input_args = RapidTableInput(
            model_type=ModelType.SLANETPLUS,
            model_dir_or_path=slanet_plus_model_path,
            # OCR results are always supplied externally (see predict/batch_predict).
            use_ocr=False,
        )
        self.table_model = CustomRapidTable(input_args)
        self.ocr_engine = ocr_engine

    @staticmethod
    def _parse_ocr_result(raw_ocr_result):
        """Split raw OCR items into the ``(boxes, texts, scores)`` triple used by rapid_table.

        Accepted item layouts:
          * ``[box, text, score]``
          * ``[box, (text, score)]`` -- a list is also accepted for the pair.
        Items in any other layout are skipped. Texts are HTML-escaped.
        ``None`` is treated as an empty result (presumably what the OCR engine
        yields when no text is found -- verify against the engine).
        """
        boxes, texts, scores = [], [], []
        if raw_ocr_result is None:
            return boxes, texts, scores
        for item in raw_ocr_result:
            if len(item) == 3:
                box, text, score = item
            elif (
                len(item) == 2
                and isinstance(item[1], (tuple, list))
                and len(item[1]) >= 2
            ):
                box = item[0]
                text, score = item[1][0], item[1][1]
            else:
                continue
            boxes.append(box)
            texts.append(escape_html(text))
            scores.append(score)
        return boxes, texts, scores

    def predict(self, image, ocr_result=None):
        """Recognise a single table image.

        Args:
            image: Table image in RGB channel order (it is converted RGB -> BGR
                before being passed to the OCR engine).
            ocr_result: Optional OCR already in rapid_table format, i.e. a
                one-element list ``[(boxes, texts, scores)]``. When falsy, OCR is
                run with ``self.ocr_engine``.

        Returns:
            ``(html_code, table_cell_bboxes, logic_points, elapse)`` from the
            table model's output; ``html_code`` is its ``pred_htmls`` list.
            ``(None, None, None, None)`` if table recognition raises (the
            exception is logged).
        """
        if not ocr_result:
            bgr_image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
            raw_ocr_result = self.ocr_engine.ocr(bgr_image)[0]
            ocr_result = [self._parse_ocr_result(raw_ocr_result)]

        try:
            table_results = self.table_model(
                img_contents=np.asarray(image), ocr_results=ocr_result
            )
        except Exception as e:
            # Best-effort: log and signal failure to the caller with Nones.
            logger.exception(e)
            return None, None, None, None
        return (
            table_results.pred_htmls,
            table_results.cell_bboxes,
            table_results.logic_points,
            table_results.elapse,
        )

    def batch_predict(self, table_res_list: List[dict], batch_size: int = 4):
        """Recognise several tables in one call and write the HTML back in place.

        Entries whose ``"ocr_result"`` is missing or falsy are skipped. For the
        rest, ``"table_img"`` is the image and ``"ocr_result"`` holds raw OCR items
        (layouts as in ``_parse_ocr_result``). Each non-empty predicted HTML string
        is stored in ``entry["table_res"]["html"]``. Returns ``None``.
        """
        valid_table_res_list = [
            table_res for table_res in table_res_list
            if table_res.get("ocr_result", None)
        ]
        if not valid_table_res_list:
            return

        img_contents = [table_res["table_img"] for table_res in valid_table_res_list]
        ocr_results = [
            self._parse_ocr_result(table_res["ocr_result"])
            for table_res in valid_table_res_list
        ]
        table_results = self.table_model(
            img_contents=img_contents, ocr_results=ocr_results, batch_size=batch_size
        )
        for table_res, html_code in zip(valid_table_res_list, table_results.pred_htmls):
            if html_code:
                table_res["table_res"]["html"] = html_code
if __name__ == '__main__':
    import sys

    # Demo: recognise one table image and print the predicted HTML.
    # The image path may be given as the first CLI argument; otherwise the
    # original sample path is used.
    default_img_path = r"D:\project\20240729ocrtest\pythonProject\images\601c939cc6dabaf07af763e2f935f54896d0251f37cc47beb7fc6b069353455d.jpg"
    img_path = Path(sys.argv[1] if len(sys.argv) > 1 else default_img_path)
    bgr_image = cv2.imread(str(img_path))
    if bgr_image is None:
        # cv2.imread reports failure by returning None rather than raising.
        sys.exit(f"Failed to read image: {img_path}")
    # predict() expects RGB (it converts RGB -> BGR for OCR), whereas
    # cv2.imread returns BGR, so convert first.
    image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)

    ocr_engine = PytorchPaddleOCR(
        det_db_box_thresh=0.5,
        det_db_unclip_ratio=1.6,
        enable_merge_det_boxes=False,
    )
    table_model = RapidTableModel(ocr_engine)
    html_code, table_cell_bboxes, logic_points, elapse = table_model.predict(image)
    print(html_code)