493 lines
16 KiB
Python
493 lines
16 KiB
Python
import os
|
|
import traceback
|
|
from enum import Enum
|
|
from io import BytesIO
|
|
from pathlib import Path
|
|
from typing import List, Union, Dict, Any, Tuple, Optional
|
|
|
|
import cv2
|
|
import loguru
|
|
import numpy as np
|
|
from onnxruntime import (
|
|
GraphOptimizationLevel,
|
|
InferenceSession,
|
|
SessionOptions,
|
|
get_available_providers,
|
|
)
|
|
from PIL import Image, UnidentifiedImageError
|
|
|
|
|
|
# Directory that contains this module.
root_dir = Path(__file__).resolve().parents[0]

# Input kinds accepted by the image loader: a path given as str or Path,
# an in-memory ndarray, or raw bytes.
InputType = Union[str, np.ndarray, bytes, Path]
|
|
|
|
|
|
class EP(Enum):
    """Execution-provider names passed to ONNX Runtime."""

    CPU_EP = "CPUExecutionProvider"
|
|
|
|
|
|
class OrtInferSession:
    """Wrapper around an ONNX Runtime ``InferenceSession`` using the CPU provider.

    Args:
        config: Settings dictionary. The keys read here are ``model_path``
            and the optional ``intra_op_num_threads`` /
            ``inter_op_num_threads``.
    """

    def __init__(self, config: Dict[str, Any]):
        self.logger = loguru.logger

        model_path = config.get("model_path", None)

        self.had_providers: List[str] = get_available_providers()
        EP_list = self._get_ep_list()

        sess_opt = self._init_sess_opts(config)
        self.session = InferenceSession(
            model_path,
            sess_options=sess_opt,
            providers=EP_list,
        )

    @staticmethod
    def _is_valid_thread_num(num_threads: Any, cpu_nums: Optional[int]) -> bool:
        """Return True when ``num_threads`` is an int in ``[1, cpu_nums]``.

        ``cpu_nums`` is ``None`` when the CPU count cannot be determined; in
        that case only the lower bound is enforced.
        """
        if not isinstance(num_threads, int) or num_threads < 1:
            return False
        return cpu_nums is None or num_threads <= cpu_nums

    @staticmethod
    def _init_sess_opts(config: Dict[str, Any]) -> SessionOptions:
        """Build the session options, applying thread counts only when valid.

        A missing, non-integer or out-of-range thread count is ignored and
        the corresponding ONNX Runtime setting is left untouched.
        """
        sess_opt = SessionOptions()
        # Severity 4 is "fatal" in ONNX Runtime's logging levels.
        sess_opt.log_severity_level = 4
        sess_opt.enable_cpu_mem_arena = False
        sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL

        # os.cpu_count() may return None, so it must not be compared directly.
        cpu_nums = os.cpu_count()

        intra_op_num_threads = config.get("intra_op_num_threads", -1)
        if OrtInferSession._is_valid_thread_num(intra_op_num_threads, cpu_nums):
            sess_opt.intra_op_num_threads = intra_op_num_threads

        inter_op_num_threads = config.get("inter_op_num_threads", -1)
        if OrtInferSession._is_valid_thread_num(inter_op_num_threads, cpu_nums):
            sess_opt.inter_op_num_threads = inter_op_num_threads

        return sess_opt

    def _get_ep_list(self) -> List[Tuple[str, Dict[str, Any]]]:
        """Return the ``(provider name, provider options)`` pairs for the session."""
        cpu_provider_opts = {
            "arena_extend_strategy": "kSameAsRequested",
        }
        EP_list = [(EP.CPU_EP.value, cpu_provider_opts)]

        return EP_list

    def __call__(self, input_content: List[np.ndarray]) -> List[np.ndarray]:
        """Run the model.

        Args:
            input_content: Input arrays, paired positionally with the model's
                input names.

        Returns:
            The list of outputs produced by ``InferenceSession.run``.

        Raises:
            ONNXRuntimeError: If the run fails; the message carries the
                formatted traceback of the underlying exception.
        """
        input_dict = dict(zip(self.get_input_names(), input_content))
        try:
            return self.session.run(None, input_dict)
        except Exception as e:
            error_info = traceback.format_exc()
            raise ONNXRuntimeError(error_info) from e

    def get_input_names(self) -> List[str]:
        """Return the names of the model inputs in declaration order."""
        return [v.name for v in self.session.get_inputs()]
|
|
|
|
|
|
class ONNXRuntimeError(Exception):
    """Raised when running the ONNX Runtime session fails."""
|
|
|
|
|
|
class LoadImage:
    """Load an image from a path, raw bytes or an ndarray and return 3-channel BGR."""

    def __init__(
        self,
    ):
        pass

    def __call__(self, img: InputType) -> np.ndarray:
        """Load ``img`` and convert it to a 3-channel array.

        Raises:
            LoadImageError: If the input type is unsupported, the file does
                not exist, the file cannot be identified as an image, or the
                array has an unsupported number of dimensions/channels.
        """
        if not isinstance(img, InputType.__args__):
            raise LoadImageError(
                f"The img type {type(img)} does not in {InputType.__args__}"
            )

        img = self.load_img(img)
        img = self.convert_img(img)
        return img

    def load_img(self, img: InputType) -> np.ndarray:
        """Decode ``img`` into an ndarray; ndarrays are returned unchanged."""
        if isinstance(img, (str, Path)):
            self.verify_exist(img)
            try:
                # The context manager closes the underlying file handle even
                # for formats Pillow keeps open after decoding.
                with Image.open(img) as pil_img:
                    return np.array(pil_img)
            except UnidentifiedImageError as e:
                raise LoadImageError(f"cannot identify image file {img}") from e

        if isinstance(img, bytes):
            with Image.open(BytesIO(img)) as pil_img:
                return np.array(pil_img)

        if isinstance(img, np.ndarray):
            return img

        raise LoadImageError(f"{type(img)} is not supported!")

    def convert_img(self, img: np.ndarray):
        """Convert a 2-D or 3-D array to 3-channel BGR.

        A 3-channel input is passed through an RGB->BGR swap, i.e. it is
        treated as RGB.
        """
        if img.ndim == 2:
            return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

        if img.ndim == 3:
            channel = img.shape[2]
            if channel == 1:
                return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

            if channel == 2:
                return self.cvt_two_to_three(img)

            if channel == 4:
                return self.cvt_four_to_three(img)

            if channel == 3:
                return cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

            raise LoadImageError(
                f"The channel({channel}) of the img is not in [1, 2, 3, 4]"
            )

        raise LoadImageError(f"The ndim({img.ndim}) of the img is not in [2, 3]")

    @staticmethod
    def cvt_four_to_three(img: np.ndarray) -> np.ndarray:
        """RGBA → BGR"""
        r, g, b, a = cv2.split(img)
        new_img = cv2.merge((b, g, r))

        # Inverted alpha, expanded to three channels, is added back below so
        # that pixels with zero alpha end up white.
        not_a = cv2.bitwise_not(a)
        not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR)

        # Keep colour only where the alpha mask is non-zero.
        new_img = cv2.bitwise_and(new_img, new_img, mask=a)
        new_img = cv2.add(new_img, not_a)
        return new_img

    @staticmethod
    def cvt_two_to_three(img: np.ndarray) -> np.ndarray:
        """gray + alpha → BGR"""
        img_gray = img[..., 0]
        img_bgr = cv2.cvtColor(img_gray, cv2.COLOR_GRAY2BGR)

        img_alpha = img[..., 1]
        not_a = cv2.bitwise_not(img_alpha)
        not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR)

        new_img = cv2.bitwise_and(img_bgr, img_bgr, mask=img_alpha)
        new_img = cv2.add(new_img, not_a)
        return new_img

    @staticmethod
    def verify_exist(file_path: Union[str, Path]):
        """Raise ``LoadImageError`` if ``file_path`` does not exist."""
        if not Path(file_path).exists():
            raise LoadImageError(f"{file_path} does not exist.")
|
|
|
|
|
|
class LoadImageError(Exception):
    """Raised when an image cannot be loaded or converted."""
|
|
|
|
|
|
# Pillow >= 9.1.0 exposes the resampling filters on ``Image.Resampling``;
# older releases expose them directly on ``Image``. The constants are looked
# up on whichever namespace is available.
if Image is not None:
    pillow_interp_codes = {
        name: getattr(getattr(Image, "Resampling", Image), name.upper())
        for name in ("nearest", "bilinear", "bicubic", "box", "lanczos", "hamming")
    }

# Interpolation names mapped to OpenCV flags.
cv2_interp_codes = dict(
    nearest=cv2.INTER_NEAREST,
    bilinear=cv2.INTER_LINEAR,
    bicubic=cv2.INTER_CUBIC,
    area=cv2.INTER_AREA,
    lanczos=cv2.INTER_LANCZOS4,
)
|
|
|
|
|
|
def resize_img(img, scale, keep_ratio=True):
    """Resize ``img`` and report the per-axis scale factors.

    Args:
        img (ndarray): The input image.
        scale (float | int | tuple[int]): With ``keep_ratio`` this is either
            a scaling factor or a size bound, as accepted by ``imrescale``.
            Without ``keep_ratio`` it is the target size ``(w, h)`` passed to
            ``imresize``.
        keep_ratio (bool): Whether to preserve the aspect ratio.

    Returns:
        tuple: ``(resized_img, w_scale, h_scale)``.
    """
    if keep_ratio:
        # Area interpolation preserves more detail when shrinking.
        if isinstance(scale, (float, int)):
            # A numeric scale is a factor; below 1 means the image shrinks.
            shrinking = scale < 1
        else:
            shrinking = min(img.shape[:2]) > min(scale)
        interpolation = "area" if shrinking else "bicubic"

        img_new = imrescale(img, scale, interpolation=interpolation)
        # Rounding to whole pixels makes the two factors differ slightly, so
        # they are derived from the actual output shape.
        new_h, new_w = img_new.shape[:2]
        h, w = img.shape[:2]
        w_scale = new_w / w
        h_scale = new_h / h
    else:
        img_new, w_scale, h_scale = imresize(img, scale, return_scale=True)
    return img_new, w_scale, h_scale
|
|
|
|
|
|
def imrescale(img, scale, return_scale=False, interpolation="bilinear", backend=None):
    """Resize an image while preserving its aspect ratio.

    Args:
        img (ndarray): The input image.
        scale (float | tuple[int]): A scaling factor, or a tuple of two
            integers bounding the output size; in the latter case the image
            is made as large as possible within that bound.
        return_scale (bool): Whether to also return the scaling factor.
        interpolation (str): Forwarded to ``imresize``.
        backend (str | None): Forwarded to ``imresize``.

    Returns:
        ndarray | tuple: The rescaled image, or ``(image, scale_factor)``
        when ``return_scale`` is true.
    """
    height, width = img.shape[:2]
    target_size, factor = rescale_size((width, height), scale, return_scale=True)
    resized = imresize(
        img, target_size, interpolation=interpolation, backend=backend
    )
    return (resized, factor) if return_scale else resized
|
|
|
|
|
|
def imresize(
    img, size, return_scale=False, interpolation="bilinear", out=None, backend=None
):
    """Resize image to a given size.

    Args:
        img (ndarray): The input image.
        size (tuple[int]): Target size (w, h).
        return_scale (bool): Whether to return `w_scale` and `h_scale`.
        interpolation (str): Interpolation method. It must be a key of
            ``cv2_interp_codes`` for the 'cv2' backend or of
            ``pillow_interp_codes`` for the 'pillow' backend.
        out (ndarray): The output destination. Only used by the 'cv2'
            backend.
        backend (str | None): The image resize backend type. Options are
            `cv2`, `pillow`, `None`. ``None`` selects `cv2`. Default: None.

    Returns:
        tuple | ndarray: (`resized_img`, `w_scale`, `h_scale`) or
        `resized_img`.

    Raises:
        ValueError: If ``backend`` is not one of the supported backends.
    """
    h, w = img.shape[:2]
    if backend is None:
        backend = "cv2"
    if backend not in ["cv2", "pillow"]:
        raise ValueError(
            f"backend: {backend} is not supported for resize. "
            "Supported backends are 'cv2', 'pillow'"
        )

    if backend == "pillow":
        assert img.dtype == np.uint8, "Pillow backend only support uint8 type"
        pil_image = Image.fromarray(img)
        pil_image = pil_image.resize(size, pillow_interp_codes[interpolation])
        resized_img = np.array(pil_image)
    else:
        resized_img = cv2.resize(
            img, size, dst=out, interpolation=cv2_interp_codes[interpolation]
        )
    if not return_scale:
        return resized_img

    # The factors come from the requested size, not the output array.
    w_scale = size[0] / w
    h_scale = size[1] / h
    return resized_img, w_scale, h_scale
|
|
|
|
|
|
def rescale_size(old_size, scale, return_scale=False):
    """Compute the size an image should be rescaled to.

    Args:
        old_size (tuple[int]): Current size as (w, h).
        scale (float | tuple[int]): A positive scaling factor, or a tuple of
            two integers bounding the output size; in the latter case the
            largest factor that keeps the image within the bound is used.
        return_scale (bool): Whether to also return the scaling factor.

    Returns:
        tuple[int] | tuple: The new (w, h), or ``((w, h), scale_factor)``
        when ``return_scale`` is true.

    Raises:
        ValueError: If a numeric ``scale`` is not positive.
        TypeError: If ``scale`` is neither a number nor a tuple.
    """
    width, height = old_size

    if not isinstance(scale, (float, int, tuple)):
        raise TypeError(
            f"Scale must be a number or tuple of int, but got {type(scale)}"
        )

    if isinstance(scale, tuple):
        # The longer limit bounds the longer edge and the shorter limit the
        # shorter edge; the tighter of the two constraints wins.
        long_limit = max(scale)
        short_limit = min(scale)
        factor = min(
            long_limit / max(height, width), short_limit / min(height, width)
        )
    else:
        if scale <= 0:
            raise ValueError(f"Invalid scale {scale}, must be positive.")
        factor = scale

    scaled = _scale_size((width, height), factor)
    return (scaled, factor) if return_scale else scaled
|
|
|
|
|
|
def _scale_size(size, scale):
|
|
"""Rescale a size by a ratio.
|
|
|
|
Args:
|
|
size (tuple[int]): (w, h).
|
|
scale (float | tuple(float)): Scaling factor.
|
|
|
|
Returns:
|
|
tuple[int]: scaled size.
|
|
"""
|
|
if isinstance(scale, (float, int)):
|
|
scale = (scale, scale)
|
|
w, h = size
|
|
return int(w * float(scale[0]) + 0.5), int(h * float(scale[1]) + 0.5)
|
|
|
|
|
|
class VisTable:
    """Visualise table-recognition results as HTML and annotated images."""

    def __init__(self):
        self.load_img = LoadImage()

    def __call__(
        self,
        img_path: Union[str, Path],
        table_results,
        save_html_path: Optional[Union[str, Path]] = None,
        save_drawed_path: Optional[Union[str, Path]] = None,
        save_logic_path: Optional[Union[str, Path]] = None,
    ):
        """Render the results and optionally write them to disk.

        Args:
            img_path: Source image handed to the image loader.
            table_results: Object exposing ``pred_html``, ``cell_bboxes`` and
                ``logic_points``. ``cell_bboxes`` is indexed as a 2-D array
                whose rows hold 4 values (x1, y1, x2, y2) or 8 values (four
                corner points).
            save_html_path: Where to write the bordered HTML, if given.
            save_drawed_path: Where to write the image with cell outlines,
                if given.
            save_logic_path: Where to write the image annotated with
                row/column indices, if given.

        Returns:
            The image with cell outlines drawn, or ``None`` when
            ``cell_bboxes`` is ``None``.

        Raises:
            ValueError: If the boxes have neither 4 nor 8 values per row.
        """
        if save_html_path:
            html_with_border = self.insert_border_style(table_results.pred_html)
            self.save_html(save_html_path, html_with_border)

        table_cell_bboxes = table_results.cell_bboxes
        table_logic_points = table_results.logic_points
        if table_cell_bboxes is None:
            return None

        img = self.load_img(img_path)

        dims_bboxes = table_cell_bboxes.shape[1]
        if dims_bboxes == 4:
            drawed_img = self.draw_rectangle(img, table_cell_bboxes)
        elif dims_bboxes == 8:
            drawed_img = self.draw_polylines(img, table_cell_bboxes)
        else:
            raise ValueError("Shape of table bounding boxes is not between in 4 or 8.")

        if save_drawed_path:
            self.save_img(save_drawed_path, drawed_img)

        # The logic overlay needs logic points; skip it when there are none.
        if save_logic_path and table_logic_points is not None:
            if dims_bboxes == 4:
                # Rows are already (xmin, ymin, xmax, ymax).
                polygons = [
                    [box[0], box[1], box[2], box[3]] for box in table_cell_bboxes
                ]
            else:
                # Use the first and third corner points as opposite corners.
                polygons = [
                    [box[0], box[1], box[4], box[5]] for box in table_cell_bboxes
                ]
            self.plot_rec_box_with_logic_info(
                img, save_logic_path, table_logic_points, polygons
            )
        return drawed_img

    def insert_border_style(self, table_html_str: str):
        """Insert a charset tag and table-border CSS before ``<body>``.

        The style block goes immediately before the first ``<body>`` tag.
        When the HTML has no ``<body>`` tag, the block is prepended instead.
        """
        style_res = """<meta charset="UTF-8"><style>
        table {
            border-collapse: collapse;
            width: 100%;
        }
        th, td {
            border: 1px solid black;
            padding: 8px;
            text-align: center;
        }
        th {
            background-color: #f2f2f2;
        }
        </style>"""

        prefix_table, body_tag, suffix_table = table_html_str.partition("<body>")
        if not body_tag:
            return f"{style_res}{table_html_str}"
        html_with_border = f"{prefix_table}{style_res}{body_tag}{suffix_table}"
        return html_with_border

    def plot_rec_box_with_logic_info(
        self, img, output_path, logic_points, sorted_polygons
    ):
        """Draw each cell box with its row/column span and save the image.

        :param img: image to annotate; a padded copy is drawn on
        :param output_path: destination file; parent directories are created
        :param logic_points: [row_start,row_end,col_start,col_end]
        :param sorted_polygons: [xmin,ymin,xmax,ymax]
        :return: None
        """
        # Pad 100 px of white on the right side of the image.
        img = cv2.copyMakeBorder(
            img, 0, 0, 0, 100, cv2.BORDER_CONSTANT, value=[255, 255, 255]
        )
        # Font size and line width for the labels (font scale was 0.5 before).
        font_scale = 0.9
        thickness = 1
        # Draw the cell rectangles and their logic labels.
        for idx, polygon in enumerate(sorted_polygons):
            # int() guarantees Python ints even when round() returns a numpy
            # scalar.
            x0, y0, x1, y1 = (int(round(v)) for v in polygon[:4])
            cv2.rectangle(img, (x0, y0), (x1, y1), (0, 0, 255), 1)
            logic_point = logic_points[idx]
            cv2.putText(
                img,
                f"row: {logic_point[0]}-{logic_point[1]}",
                (x0 + 3, y0 + 8),
                cv2.FONT_HERSHEY_PLAIN,
                font_scale,
                (0, 0, 255),
                thickness,
            )
            cv2.putText(
                img,
                f"col: {logic_point[2]}-{logic_point[3]}",
                (x0 + 3, y0 + 18),
                cv2.FONT_HERSHEY_PLAIN,
                font_scale,
                (0, 0, 255),
                thickness,
            )
        # Path.parent is "." for a bare filename, which os.path.dirname
        # would report as "" and os.makedirs would reject.
        Path(output_path).parent.mkdir(parents=True, exist_ok=True)
        # Save the annotated image.
        self.save_img(output_path, img)

    @staticmethod
    def draw_rectangle(img: np.ndarray, boxes: np.ndarray) -> np.ndarray:
        """Return a copy of ``img`` with each (x1, y1, x2, y2) box outlined."""
        img_copy = img.copy()
        for box in boxes.astype(int):
            x1, y1, x2, y2 = box
            cv2.rectangle(img_copy, (x1, y1), (x2, y2), (255, 0, 0), 2)
        return img_copy

    @staticmethod
    def draw_polylines(img: np.ndarray, points) -> np.ndarray:
        """Return a copy of ``img`` with each 4-point polygon outlined."""
        img_copy = img.copy()
        for point in points.astype(int):
            cv2.polylines(img_copy, [point.reshape(4, 2)], True, (255, 0, 0), 2)
        return img_copy

    @staticmethod
    def save_img(save_path: Union[str, Path], img: np.ndarray):
        """Write ``img`` to ``save_path`` with OpenCV."""
        cv2.imwrite(str(save_path), img)

    @staticmethod
    def save_html(save_path: Union[str, Path], html: str):
        """Write ``html`` to ``save_path`` as UTF-8 text."""
        with open(save_path, "w", encoding="utf-8") as f:
            f.write(html)
|