155 lines
5.8 KiB
Python
155 lines
5.8 KiB
Python
import html
|
|
import os
|
|
import time
|
|
from pathlib import Path
|
|
from typing import List
|
|
|
|
import cv2
|
|
import numpy as np
|
|
from loguru import logger
|
|
from rapid_table import ModelType, RapidTable, RapidTableInput
|
|
from rapid_table.utils import RapidTableOutput
|
|
from tqdm import tqdm
|
|
|
|
from mineru.model.ocr.pytorch_paddle import PytorchPaddleOCR
|
|
from mineru.utils.enum_class import ModelPath
|
|
from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
|
|
|
|
|
|
def escape_html(input_string: str) -> str:
    """Return *input_string* with HTML-special characters turned into entities.

    Thin wrapper over :func:`html.escape` using its default settings.
    """
    escaped = html.escape(input_string)
    return escaped
|
|
|
|
|
|
class CustomRapidTable(RapidTable):
    """RapidTable subclass whose batched ``__call__`` reports progress via tqdm."""

    def __init__(self, cfg: RapidTableInput):
        """Initialise the parent ``RapidTable`` with *cfg*, silencing INFO logs first."""
        import logging

        # NOTE: logging.disable() is unconditional and process-wide: every
        # stdlib logging record at INFO level or below is dropped from here on,
        # not only the ones emitted by rapid_table.
        logging.disable(logging.INFO)
        super().__init__(cfg)

    def __call__(self, img_contents, ocr_results=None, batch_size=1):
        """Run table recognition over one image or a list of images.

        Args:
            img_contents: A single image input or a list of them; a non-list
                value is wrapped into a one-element list.
            ocr_results: OCR results forwarded to ``self.get_ocr_results``
                together with the index range of the current batch.
            batch_size: Number of images handled per iteration.

        Returns:
            A ``RapidTableOutput`` whose ``pred_htmls``, ``cell_bboxes`` and
            ``logic_points`` are extended batch by batch, and whose ``elapse``
            is the average wall-clock seconds per image (``0.0`` when there
            were no images).
        """
        if not isinstance(img_contents, list):
            img_contents = [img_contents]

        started = time.perf_counter()
        results = RapidTableOutput()
        total_nums = len(img_contents)

        with tqdm(total=total_nums, desc="Table-wireless Predict") as pbar:
            for start_i in range(0, total_nums, batch_size):
                end_i = min(total_nums, start_i + batch_size)

                imgs = self._load_imgs(img_contents[start_i:end_i])

                pred_structures, cell_bboxes = self.table_structure(imgs)
                logic_points = self.table_matcher.decode_logic_points(pred_structures)

                dt_boxes, rec_res = self.get_ocr_results(imgs, start_i, end_i, ocr_results)
                pred_htmls = self.table_matcher(
                    pred_structures, cell_bboxes, dt_boxes, rec_res
                )

                results.pred_htmls.extend(pred_htmls)
                # Keep the per-image geometry as well, so callers that read
                # results.cell_bboxes / results.logic_points get real data.
                # Assumes both fields are lists like pred_htmls - TODO confirm
                # against the installed rapid_table version.
                results.cell_bboxes.extend(cell_bboxes)
                results.logic_points.extend(logic_points)

                # Advance the progress bar by the number of images just handled.
                pbar.update(end_i - start_i)

        elapsed = time.perf_counter() - started
        # An empty input list would otherwise divide by zero.
        results.elapse = elapsed / total_nums if total_nums else 0.0
        return results
|
|
|
|
|
|
class RapidTableModel():
    """Table-to-HTML recogniser combining a SLANet-plus table model with an OCR engine."""

    def __init__(self, ocr_engine):
        """Create the SLANet-plus ``CustomRapidTable`` and remember *ocr_engine*.

        Args:
            ocr_engine: Object whose ``ocr(image)`` method is used by
                :meth:`predict` when no OCR result is supplied.
        """
        model_root = auto_download_and_get_model_root_path(ModelPath.slanet_plus)
        slanet_plus_model_path = os.path.join(model_root, ModelPath.slanet_plus)
        input_args = RapidTableInput(
            model_type=ModelType.SLANETPLUS,
            model_dir_or_path=slanet_plus_model_path,
            use_ocr=False,
        )
        self.table_model = CustomRapidTable(input_args)
        self.ocr_engine = ocr_engine

    @staticmethod
    def _split_ocr_items(raw_ocr_result):
        """Split raw OCR items into parallel ``(boxes, texts, scores)`` lists.

        Two item layouts are recognised: ``(box, text, score)`` and
        ``(box, (text, score))``. Items of any other shape are skipped.
        Texts are HTML-escaped. ``None`` is treated as "no items".
        """
        boxes = []
        texts = []
        scores = []
        for item in raw_ocr_result or ():
            if len(item) == 3:
                box, text, score = item[0], item[1], item[2]
            elif len(item) == 2 and isinstance(item[1], tuple):
                box, text, score = item[0], item[1][0], item[1][1]
            else:
                continue
            boxes.append(box)
            texts.append(escape_html(text))
            scores.append(score)
        return boxes, texts, scores

    def predict(self, image, ocr_result=None):
        """Recognise a single table image.

        Args:
            image: Image convertible with ``np.asarray``; it is treated as RGB
                when it has to be converted for the OCR engine.
            ocr_result: Optional OCR result passed to the table model as-is.
                When falsy, OCR is run on the image with ``self.ocr_engine``.

        Returns:
            ``(pred_htmls, cell_bboxes, logic_points, elapse)`` taken from the
            table model output, or ``(None, None, None, None)`` if the table
            model raised (the exception is logged).
        """
        if not ocr_result:
            # The colour conversion is only needed for the OCR engine, so it is
            # skipped entirely when the caller already supplied OCR output.
            bgr_image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
            raw_ocr_result = self.ocr_engine.ocr(bgr_image)[0]
            # The engine may yield None for an image without detected text
            # (TODO confirm for the configured engine); the helper maps that
            # to empty lists instead of failing on iteration.
            # Wrap in a list: the table model expects one entry per image.
            ocr_result = [self._split_ocr_items(raw_ocr_result)]

        try:
            table_results = self.table_model(img_contents=np.asarray(image), ocr_results=ocr_result)
            return (
                table_results.pred_htmls,
                table_results.cell_bboxes,
                table_results.logic_points,
                table_results.elapse,
            )
        except Exception as e:
            logger.exception(e)

        return None, None, None, None

    def batch_predict(self, table_res_list: List[dict], batch_size: int = 4):
        """Recognise many tables and write the HTML back into each entry.

        Entries whose ``"ocr_result"`` is missing or falsy are ignored. For the
        rest, ``"table_img"`` and the reformatted ``"ocr_result"`` are sent to
        the table model in batches of *batch_size*; every non-empty predicted
        HTML string is stored at ``entry["table_res"]["html"]``.

        Args:
            table_res_list: Dicts with ``"table_img"``, ``"ocr_result"`` and a
                mutable ``"table_res"`` mapping.
            batch_size: Batch size forwarded to the table model.
        """
        with_ocr = [entry for entry in table_res_list if entry.get("ocr_result", None)]
        if not with_ocr:
            return

        img_contents = [entry["table_img"] for entry in with_ocr]
        # Convert each raw OCR result into the (boxes, texts, scores) layout
        # that the table model is given.
        ocr_results = [self._split_ocr_items(entry["ocr_result"]) for entry in with_ocr]
        table_results = self.table_model(img_contents=img_contents, ocr_results=ocr_results, batch_size=batch_size)

        for entry, pred_html in zip(with_ocr, table_results.pred_htmls):
            if pred_html:
                entry['table_res']['html'] = pred_html
|
|
|
|
if __name__ == '__main__':
    # Manual smoke test: recognise one table image and print the predicted HTML.
    # Usage: python <this file> [image_path]
    import sys

    default_img_path = r"D:\project\20240729ocrtest\pythonProject\images\601c939cc6dabaf07af763e2f935f54896d0251f37cc47beb7fc6b069353455d.jpg"
    img_path = Path(sys.argv[1] if len(sys.argv) > 1 else default_img_path)

    # cv2.imread signals failure by returning None instead of raising, so
    # check before loading any model.
    bgr_image = cv2.imread(str(img_path))
    if bgr_image is None:
        raise FileNotFoundError(f"Could not read image: {img_path}")

    # cv2.imread yields BGR data, whereas RapidTableModel.predict converts its
    # input with COLOR_RGB2BGR; hand it RGB so the OCR engine sees true colours.
    image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)

    ocr_engine = PytorchPaddleOCR(
        det_db_box_thresh=0.5,
        det_db_unclip_ratio=1.6,
        enable_merge_det_boxes=False,
    )
    table_model = RapidTableModel(ocr_engine)

    html_code, table_cell_bboxes, logic_points, elapse = table_model.predict(image)
    print(html_code)
|
|
|