301 lines
9.0 KiB
Python
301 lines
9.0 KiB
Python
# Copyright (c) Opendatalab. All rights reserved.
|
||
import copy
|
||
import os
|
||
import warnings
|
||
from pathlib import Path
|
||
|
||
import cv2
|
||
import numpy as np
|
||
import yaml
|
||
from loguru import logger
|
||
|
||
from mineru.utils.config_reader import get_device
|
||
from mineru.utils.enum_class import ModelPath
|
||
from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
|
||
from mineru.utils.ocr_utils import check_img, preprocess_image, sorted_boxes, merge_det_boxes, update_det_boxes, get_rotate_crop_image
|
||
from mineru.model.utils.tools.infer.predict_system import TextSystem
|
||
from mineru.model.utils.tools.infer import pytorchocr_utility as utility
|
||
import argparse
|
||
|
||
|
||
latin_lang = [
|
||
"af",
|
||
"az",
|
||
"bs",
|
||
"cs",
|
||
"cy",
|
||
"da",
|
||
"de",
|
||
"es",
|
||
"et",
|
||
"fr",
|
||
"ga",
|
||
"hr",
|
||
"hu",
|
||
"id",
|
||
"is",
|
||
"it",
|
||
"ku",
|
||
"la",
|
||
"lt",
|
||
"lv",
|
||
"mi",
|
||
"ms",
|
||
"mt",
|
||
"nl",
|
||
"no",
|
||
"oc",
|
||
"pi",
|
||
"pl",
|
||
"pt",
|
||
"ro",
|
||
"rs_latin",
|
||
"sk",
|
||
"sl",
|
||
"sq",
|
||
"sv",
|
||
"sw",
|
||
"tl",
|
||
"tr",
|
||
"uz",
|
||
"vi",
|
||
"french",
|
||
"german",
|
||
"fi",
|
||
"eu",
|
||
"gl",
|
||
"lb",
|
||
"rm",
|
||
"ca",
|
||
"qu",
|
||
]
|
||
arabic_lang = ["ar", "fa", "ug", "ur", "ps", "ku", "sd", "bal"]
|
||
cyrillic_lang = [
|
||
"ru",
|
||
"rs_cyrillic",
|
||
"be",
|
||
"bg",
|
||
"uk",
|
||
"mn",
|
||
"abq",
|
||
"ady",
|
||
"kbd",
|
||
"ava",
|
||
"dar",
|
||
"inh",
|
||
"che",
|
||
"lbe",
|
||
"lez",
|
||
"tab",
|
||
"kk",
|
||
"ky",
|
||
"tg",
|
||
"mk",
|
||
"tt",
|
||
"cv",
|
||
"ba",
|
||
"mhr",
|
||
"mo",
|
||
"udm",
|
||
"kv",
|
||
"os",
|
||
"bua",
|
||
"xal",
|
||
"tyv",
|
||
"sah",
|
||
"kaa",
|
||
]
|
||
east_slavic_lang = ["ru", "be", "uk"]
|
||
devanagari_lang = [
|
||
"hi",
|
||
"mr",
|
||
"ne",
|
||
"bh",
|
||
"mai",
|
||
"ang",
|
||
"bho",
|
||
"mah",
|
||
"sck",
|
||
"new",
|
||
"gom",
|
||
"sa",
|
||
"bgc",
|
||
]
|
||
|
||
|
||
def get_model_params(lang, config):
|
||
if lang in config['lang']:
|
||
params = config['lang'][lang]
|
||
det = params.get('det')
|
||
rec = params.get('rec')
|
||
dict_file = params.get('dict')
|
||
return det, rec, dict_file
|
||
else:
|
||
raise Exception (f'Language {lang} not supported')
|
||
|
||
|
||
root_dir = os.path.join(Path(__file__).resolve().parent.parent, 'utils')
|
||
|
||
|
||
class PytorchPaddleOCR(TextSystem):
|
||
def __init__(self, *args, **kwargs):
|
||
parser = utility.init_args()
|
||
args = parser.parse_args(args)
|
||
|
||
self.lang = kwargs.get('lang', 'ch')
|
||
self.enable_merge_det_boxes = kwargs.get("enable_merge_det_boxes", True)
|
||
|
||
device = get_device()
|
||
if device == 'cpu' and self.lang in ['ch', 'ch_server', 'japan', 'chinese_cht']:
|
||
# logger.warning("The current device in use is CPU. To ensure the speed of parsing, the language is automatically switched to ch_lite.")
|
||
self.lang = 'ch_lite'
|
||
|
||
if self.lang in latin_lang:
|
||
self.lang = 'latin'
|
||
elif self.lang in east_slavic_lang:
|
||
self.lang = 'east_slavic'
|
||
elif self.lang in arabic_lang:
|
||
self.lang = 'arabic'
|
||
elif self.lang in cyrillic_lang:
|
||
self.lang = 'cyrillic'
|
||
elif self.lang in devanagari_lang:
|
||
self.lang = 'devanagari'
|
||
else:
|
||
pass
|
||
|
||
models_config_path = os.path.join(root_dir, 'pytorchocr', 'utils', 'resources', 'models_config.yml')
|
||
with open(models_config_path) as file:
|
||
config = yaml.safe_load(file)
|
||
det, rec, dict_file = get_model_params(self.lang, config)
|
||
ocr_models_dir = ModelPath.pytorch_paddle
|
||
|
||
det_model_path = f"{ocr_models_dir}/{det}"
|
||
det_model_path = os.path.join(auto_download_and_get_model_root_path(det_model_path), det_model_path)
|
||
rec_model_path = f"{ocr_models_dir}/{rec}"
|
||
rec_model_path = os.path.join(auto_download_and_get_model_root_path(rec_model_path), rec_model_path)
|
||
kwargs['det_model_path'] = det_model_path
|
||
kwargs['rec_model_path'] = rec_model_path
|
||
kwargs['rec_char_dict_path'] = os.path.join(root_dir, 'pytorchocr', 'utils', 'resources', 'dict', dict_file)
|
||
kwargs['rec_batch_num'] = 6
|
||
|
||
kwargs['device'] = device
|
||
|
||
default_args = vars(args)
|
||
default_args.update(kwargs)
|
||
args = argparse.Namespace(**default_args)
|
||
|
||
super().__init__(args)
|
||
|
||
def ocr(self,
|
||
img,
|
||
det=True,
|
||
rec=True,
|
||
mfd_res=None,
|
||
tqdm_enable=False,
|
||
tqdm_desc="OCR-rec Predict",
|
||
):
|
||
assert isinstance(img, (np.ndarray, list, str, bytes))
|
||
if isinstance(img, list) and det == True:
|
||
logger.error('When input a list of images, det must be false')
|
||
exit(0)
|
||
img = check_img(img)
|
||
imgs = [img]
|
||
with warnings.catch_warnings():
|
||
warnings.simplefilter("ignore", category=RuntimeWarning)
|
||
if det and rec:
|
||
ocr_res = []
|
||
for img in imgs:
|
||
img = preprocess_image(img)
|
||
dt_boxes, rec_res = self.__call__(img, mfd_res=mfd_res)
|
||
if not dt_boxes and not rec_res:
|
||
ocr_res.append(None)
|
||
continue
|
||
tmp_res = [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
|
||
ocr_res.append(tmp_res)
|
||
return ocr_res
|
||
elif det and not rec:
|
||
ocr_res = []
|
||
for img in imgs:
|
||
img = preprocess_image(img)
|
||
dt_boxes, elapse = self.text_detector(img)
|
||
# logger.debug("dt_boxes num : {}, elapsed : {}".format(len(dt_boxes), elapse))
|
||
if dt_boxes is None:
|
||
ocr_res.append(None)
|
||
continue
|
||
dt_boxes = sorted_boxes(dt_boxes)
|
||
# merge_det_boxes 和 update_det_boxes 都会把poly转成bbox再转回poly,因此需要过滤所有倾斜程度较大的文本框
|
||
if self.enable_merge_det_boxes:
|
||
dt_boxes = merge_det_boxes(dt_boxes)
|
||
if mfd_res:
|
||
dt_boxes = update_det_boxes(dt_boxes, mfd_res)
|
||
tmp_res = [box.tolist() for box in dt_boxes]
|
||
ocr_res.append(tmp_res)
|
||
return ocr_res
|
||
elif not det and rec:
|
||
ocr_res = []
|
||
for img in imgs:
|
||
if not isinstance(img, list):
|
||
img = preprocess_image(img)
|
||
img = [img]
|
||
rec_res, elapse = self.text_recognizer(img, tqdm_enable=tqdm_enable, tqdm_desc=tqdm_desc)
|
||
# logger.debug("rec_res num : {}, elapsed : {}".format(len(rec_res), elapse))
|
||
ocr_res.append(rec_res)
|
||
return ocr_res
|
||
|
||
def __call__(self, img, mfd_res=None):
|
||
|
||
if img is None:
|
||
logger.debug("no valid image provided")
|
||
return None, None
|
||
|
||
ori_im = img.copy()
|
||
dt_boxes, elapse = self.text_detector(img)
|
||
|
||
if dt_boxes is None:
|
||
logger.debug("no dt_boxes found, elapsed : {}".format(elapse))
|
||
return None, None
|
||
else:
|
||
pass
|
||
# logger.debug("dt_boxes num : {}, elapsed : {}".format(len(dt_boxes), elapse))
|
||
img_crop_list = []
|
||
|
||
dt_boxes = sorted_boxes(dt_boxes)
|
||
|
||
# merge_det_boxes 和 update_det_boxes 都会把poly转成bbox再转回poly,因此需要过滤所有倾斜程度较大的文本框
|
||
if self.enable_merge_det_boxes:
|
||
dt_boxes = merge_det_boxes(dt_boxes)
|
||
|
||
if mfd_res:
|
||
dt_boxes = update_det_boxes(dt_boxes, mfd_res)
|
||
|
||
for bno in range(len(dt_boxes)):
|
||
tmp_box = copy.deepcopy(dt_boxes[bno])
|
||
img_crop = get_rotate_crop_image(ori_im, tmp_box)
|
||
img_crop_list.append(img_crop)
|
||
|
||
rec_res, elapse = self.text_recognizer(img_crop_list)
|
||
# logger.debug("rec_res num : {}, elapsed : {}".format(len(rec_res), elapse))
|
||
|
||
filter_boxes, filter_rec_res = [], []
|
||
for box, rec_result in zip(dt_boxes, rec_res):
|
||
text, score = rec_result
|
||
if score >= self.drop_score:
|
||
filter_boxes.append(box)
|
||
filter_rec_res.append(rec_result)
|
||
|
||
return filter_boxes, filter_rec_res
|
||
|
||
if __name__ == '__main__':
|
||
pytorch_paddle_ocr = PytorchPaddleOCR()
|
||
img = cv2.imread("/Users/myhloli/Downloads/screenshot-20250326-194348.png")
|
||
dt_boxes, rec_res = pytorch_paddle_ocr(img)
|
||
ocr_res = []
|
||
if not dt_boxes and not rec_res:
|
||
ocr_res.append(None)
|
||
else:
|
||
tmp_res = [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
|
||
ocr_res.append(tmp_res)
|
||
print(ocr_res)
|
||
|
||
|