diff --git a/magic_pdf/model/pdf_extract_kit.py b/magic_pdf/model/pdf_extract_kit.py index 826d20d1..0c0e664b 100644 --- a/magic_pdf/model/pdf_extract_kit.py +++ b/magic_pdf/model/pdf_extract_kit.py @@ -1,203 +1,28 @@ +import numpy as np +import torch from loguru import logger import os import time -from magic_pdf.libs.Constants import * -from magic_pdf.libs.clean_memory import clean_memory -from magic_pdf.model.model_list import AtomicModel +import cv2 +import yaml +from PIL import Image os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新 os.environ['YOLO_VERBOSE'] = 'False' # disable yolo logger + try: - import cv2 - import yaml - import argparse - import numpy as np - import torch import torchtext if torchtext.__version__ >= "0.18.0": torchtext.disable_torchtext_deprecation_warning() - from PIL import Image - from torchvision import transforms - from torch.utils.data import Dataset, DataLoader - from ultralytics import YOLO - from unimernet.common.config import Config - import unimernet.tasks as tasks - from unimernet.processors import load_processor - from doclayout_yolo import YOLOv10 - from rapid_table import RapidTable - from rapidocr_paddle import RapidOCR +except ImportError: + pass -except ImportError as e: - logger.exception(e) - logger.error( - 'Required dependency not installed, please install by \n' - '"pip install magic-pdf[full] --extra-index-url https://myhloli.github.io/wheels/"') - exit(1) - -from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor -from magic_pdf.model.pek_sub_modules.post_process import latex_rm_whitespace -from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR -from magic_pdf.model.pek_sub_modules.structeqtable.StructTableModel import StructTableModel -from magic_pdf.model.ppTableModel import ppTableModel - - -def table_model_init(table_model_type, model_path, max_time, _device_='cpu'): - ocr_engine = None - if table_model_type == MODEL_NAME.STRUCT_EQTABLE: - 
table_model = StructTableModel(model_path, max_time=max_time) - elif table_model_type == MODEL_NAME.TABLE_MASTER: - config = { - "model_dir": model_path, - "device": _device_ - } - table_model = ppTableModel(config) - elif table_model_type == MODEL_NAME.RAPID_TABLE: - table_model = RapidTable() - ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True) - else: - logger.error("table model type not allow") - exit(1) - - if ocr_engine: - return [table_model, ocr_engine] - else: - return table_model - - -def mfd_model_init(weight): - mfd_model = YOLO(weight) - return mfd_model - - -def mfr_model_init(weight_dir, cfg_path, _device_='cpu'): - args = argparse.Namespace(cfg_path=cfg_path, options=None) - cfg = Config(args) - cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.pth") - cfg.config.model.model_config.model_name = weight_dir - cfg.config.model.tokenizer_config.path = weight_dir - task = tasks.setup_task(cfg) - model = task.build_model(cfg) - model.to(_device_) - model.eval() - vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval) - mfr_transform = transforms.Compose([vis_processor, ]) - return [model, mfr_transform] - - -def layout_model_init(weight, config_file, device): - model = Layoutlmv3_Predictor(weight, config_file, device) - return model - - -def doclayout_yolo_model_init(weight): - model = YOLOv10(weight) - return model - - -def ocr_model_init(show_log: bool = False, det_db_box_thresh=0.3, lang=None, use_dilation=True, det_db_unclip_ratio=1.8): - if lang is not None: - model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh, lang=lang, use_dilation=use_dilation, det_db_unclip_ratio=det_db_unclip_ratio) - else: - model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh, use_dilation=use_dilation, det_db_unclip_ratio=det_db_unclip_ratio) - return model - - -class MathDataset(Dataset): - def __init__(self, 
image_paths, transform=None): - self.image_paths = image_paths - self.transform = transform - - def __len__(self): - return len(self.image_paths) - - def __getitem__(self, idx): - # if not pil image, then convert to pil image - if isinstance(self.image_paths[idx], str): - raw_image = Image.open(self.image_paths[idx]) - else: - raw_image = self.image_paths[idx] - if self.transform: - image = self.transform(raw_image) - return image - - -class AtomModelSingleton: - _instance = None - _models = {} - - def __new__(cls, *args, **kwargs): - if cls._instance is None: - cls._instance = super().__new__(cls) - return cls._instance - - def get_atom_model(self, atom_model_name: str, **kwargs): - lang = kwargs.get("lang", None) - layout_model_name = kwargs.get("layout_model_name", None) - key = (atom_model_name, layout_model_name, lang) - if key not in self._models: - self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs) - return self._models[key] - - -def atom_model_init(model_name: str, **kwargs): - - if model_name == AtomicModel.Layout: - if kwargs.get("layout_model_name") == MODEL_NAME.LAYOUTLMv3: - atom_model = layout_model_init( - kwargs.get("layout_weights"), - kwargs.get("layout_config_file"), - kwargs.get("device") - ) - elif kwargs.get("layout_model_name") == MODEL_NAME.DocLayout_YOLO: - atom_model = doclayout_yolo_model_init( - kwargs.get("doclayout_yolo_weights"), - ) - elif model_name == AtomicModel.MFD: - atom_model = mfd_model_init( - kwargs.get("mfd_weights") - ) - elif model_name == AtomicModel.MFR: - atom_model = mfr_model_init( - kwargs.get("mfr_weight_dir"), - kwargs.get("mfr_cfg_path"), - kwargs.get("device") - ) - elif model_name == AtomicModel.OCR: - atom_model = ocr_model_init( - kwargs.get("ocr_show_log"), - kwargs.get("det_db_box_thresh"), - kwargs.get("lang") - ) - elif model_name == AtomicModel.Table: - atom_model = table_model_init( - kwargs.get("table_model_name"), - kwargs.get("table_model_path"), - 
kwargs.get("table_max_time"), - kwargs.get("device") - ) - else: - logger.error("model name not allow") - exit(1) - - return atom_model - - -# Unified crop img logic -def crop_img(input_res, input_pil_img, crop_paste_x=0, crop_paste_y=0): - crop_xmin, crop_ymin = int(input_res['poly'][0]), int(input_res['poly'][1]) - crop_xmax, crop_ymax = int(input_res['poly'][4]), int(input_res['poly'][5]) - # Create a white background with an additional width and height of 50 - crop_new_width = crop_xmax - crop_xmin + crop_paste_x * 2 - crop_new_height = crop_ymax - crop_ymin + crop_paste_y * 2 - return_image = Image.new('RGB', (crop_new_width, crop_new_height), 'white') - - # Crop image - crop_box = (crop_xmin, crop_ymin, crop_xmax, crop_ymax) - cropped_img = input_pil_img.crop(crop_box) - return_image.paste(cropped_img, (crop_paste_x, crop_paste_y)) - return_list = [crop_paste_x, crop_paste_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width, crop_new_height] - return return_image, return_list +from magic_pdf.libs.Constants import * +from magic_pdf.model.model_list import AtomicModel +from magic_pdf.model.sub_modules.model_init import AtomModelSingleton +from magic_pdf.model.sub_modules.model_utils import get_res_list_from_layout_res, crop_img, clean_vram +from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import get_adjusted_mfdetrec_res, get_ocr_result_list class CustomPEKModel: @@ -243,7 +68,8 @@ class CustomPEKModel: logger.info( "DocAnalysis init, this may take some times, layout_model: {}, apply_formula: {}, apply_ocr: {}, " "apply_table: {}, table_model: {}, lang: {}".format( - self.layout_model_name, self.apply_formula, self.apply_ocr, self.apply_table, self.table_model_name, self.lang + self.layout_model_name, self.apply_formula, self.apply_ocr, self.apply_table, self.table_model_name, + self.lang ) ) # 初始化解析方案 @@ -256,17 +82,17 @@ class CustomPEKModel: # 初始化公式识别 if self.apply_formula: - # 初始化公式检测模型 self.mfd_model = atom_model_manager.get_atom_model( 
atom_model_name=AtomicModel.MFD, - mfd_weights=str(os.path.join(models_dir, self.configs["weights"][self.mfd_model_name])) + mfd_weights=str(os.path.join(models_dir, self.configs["weights"][self.mfd_model_name])), + device=self.device ) # 初始化公式解析模型 mfr_weight_dir = str(os.path.join(models_dir, self.configs["weights"][self.mfr_model_name])) mfr_cfg_path = str(os.path.join(model_config_dir, "UniMERNet", "demo.yaml")) - self.mfr_model, self.mfr_transform = atom_model_manager.get_atom_model( + self.mfr_model = atom_model_manager.get_atom_model( atom_model_name=AtomicModel.MFR, mfr_weight_dir=mfr_weight_dir, mfr_cfg_path=mfr_cfg_path, @@ -286,7 +112,8 @@ class CustomPEKModel: self.layout_model = atom_model_manager.get_atom_model( atom_model_name=AtomicModel.Layout, layout_model_name=MODEL_NAME.DocLayout_YOLO, - doclayout_yolo_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name])) + doclayout_yolo_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name])), + device=self.device ) # 初始化ocr if self.apply_ocr: @@ -299,22 +126,13 @@ class CustomPEKModel: # init table model if self.apply_table: table_model_dir = self.configs["weights"][self.table_model_name] - if self.table_model_name in [MODEL_NAME.STRUCT_EQTABLE, MODEL_NAME.TABLE_MASTER]: - self.table_model = atom_model_manager.get_atom_model( - atom_model_name=AtomicModel.Table, - table_model_name=self.table_model_name, - table_model_path=str(os.path.join(models_dir, table_model_dir)), - table_max_time=self.table_max_time, - device=self.device - ) - elif self.table_model_name in [MODEL_NAME.RAPID_TABLE]: - self.table_model, self.ocr_engine =atom_model_manager.get_atom_model( - atom_model_name=AtomicModel.Table, - table_model_name=self.table_model_name, - table_model_path=str(os.path.join(models_dir, table_model_dir)), - table_max_time=self.table_max_time, - device=self.device - ) + self.table_model = atom_model_manager.get_atom_model( + 
atom_model_name=AtomicModel.Table, + table_model_name=self.table_model_name, + table_model_path=str(os.path.join(models_dir, table_model_dir)), + table_max_time=self.table_max_time, + device=self.device + ) logger.info('DocAnalysis init done!') @@ -322,26 +140,15 @@ class CustomPEKModel: page_start = time.time() - latex_filling_list = [] - mf_image_list = [] - # layout检测 layout_start = time.time() + layout_res = [] if self.layout_model_name == MODEL_NAME.LAYOUTLMv3: # layoutlmv3 layout_res = self.layout_model(image, ignore_catids=[]) elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO: # doclayout_yolo - layout_res = [] - doclayout_yolo_res = self.layout_model.predict(image, imgsz=1024, conf=0.25, iou=0.45, verbose=True, device=self.device)[0] - for xyxy, conf, cla in zip(doclayout_yolo_res.boxes.xyxy.cpu(), doclayout_yolo_res.boxes.conf.cpu(), doclayout_yolo_res.boxes.cls.cpu()): - xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy] - new_item = { - 'category_id': int(cla.item()), - 'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax], - 'score': round(float(conf.item()), 3), - } - layout_res.append(new_item) + layout_res = self.layout_model.predict(image) layout_cost = round(time.time() - layout_start, 2) logger.info(f"layout detection time: {layout_cost}") @@ -350,58 +157,21 @@ class CustomPEKModel: if self.apply_formula: # 公式检测 mfd_start = time.time() - mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True, device=self.device)[0] + mfd_res = self.mfd_model.predict(image) logger.info(f"mfd time: {round(time.time() - mfd_start, 2)}") - for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()): - xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy] - new_item = { - 'category_id': 13 + int(cla.item()), - 'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax], - 'score': round(float(conf.item()), 2), - 'latex': '', - } - layout_res.append(new_item) - 
latex_filling_list.append(new_item) - bbox_img = pil_img.crop((xmin, ymin, xmax, ymax)) - mf_image_list.append(bbox_img) # 公式识别 mfr_start = time.time() - dataset = MathDataset(mf_image_list, transform=self.mfr_transform) - dataloader = DataLoader(dataset, batch_size=64, num_workers=0) - mfr_res = [] - for mf_img in dataloader: - mf_img = mf_img.to(self.device) - with torch.no_grad(): - output = self.mfr_model.generate({'image': mf_img}) - mfr_res.extend(output['pred_str']) - for res, latex in zip(latex_filling_list, mfr_res): - res['latex'] = latex_rm_whitespace(latex) + formula_list = self.mfr_model.predict(mfd_res, image) + layout_res.extend(formula_list) mfr_cost = round(time.time() - mfr_start, 2) - logger.info(f"formula nums: {len(mf_image_list)}, mfr time: {mfr_cost}") + logger.info(f"formula nums: {len(formula_list)}, mfr time: {mfr_cost}") - # Select regions for OCR / formula regions / table regions - ocr_res_list = [] - table_res_list = [] - single_page_mfdetrec_res = [] - for res in layout_res: - if int(res['category_id']) in [13, 14]: - single_page_mfdetrec_res.append({ - "bbox": [int(res['poly'][0]), int(res['poly'][1]), - int(res['poly'][4]), int(res['poly'][5])], - }) - elif int(res['category_id']) in [0, 1, 2, 4, 6, 7]: - ocr_res_list.append(res) - elif int(res['category_id']) in [5]: - table_res_list.append(res) + # 清理显存 + clean_vram(self.device, vram_threshold=8) - if torch.cuda.is_available() and self.device != 'cpu': - total_memory = torch.cuda.get_device_properties(self.device).total_memory / (1024 ** 3) # 将字节转换为 GB - if total_memory <= 8: - gc_start = time.time() - clean_memory() - gc_time = round(time.time() - gc_start, 2) - logger.info(f"gc time: {gc_time}") + # 从layout_res中获取ocr区域、表格区域、公式区域 + ocr_res_list, table_res_list, single_page_mfdetrec_res = get_res_list_from_layout_res(layout_res) # ocr识别 if self.apply_ocr: @@ -409,23 +179,7 @@ class CustomPEKModel: # Process each area that requires OCR processing for res in ocr_res_list: new_image, 
useful_list = crop_img(res, pil_img, crop_paste_x=50, crop_paste_y=50) - paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list - # Adjust the coordinates of the formula area - adjusted_mfdetrec_res = [] - for mf_res in single_page_mfdetrec_res: - mf_xmin, mf_ymin, mf_xmax, mf_ymax = mf_res["bbox"] - # Adjust the coordinates of the formula area to the coordinates relative to the cropping area - x0 = mf_xmin - xmin + paste_x - y0 = mf_ymin - ymin + paste_y - x1 = mf_xmax - xmin + paste_x - y1 = mf_ymax - ymin + paste_y - # Filter formula blocks outside the graph - if any([x1 < 0, y1 < 0]) or any([x0 > new_width, y0 > new_height]): - continue - else: - adjusted_mfdetrec_res.append({ - "bbox": [x0, y0, x1, y1], - }) + adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list) # OCR recognition new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR) @@ -433,22 +187,8 @@ class CustomPEKModel: # Integration results if ocr_res: - for box_ocr_res in ocr_res: - p1, p2, p3, p4 = box_ocr_res[0] - text, score = box_ocr_res[1] - - # Convert the coordinates back to the original coordinate system - p1 = [p1[0] - paste_x + xmin, p1[1] - paste_y + ymin] - p2 = [p2[0] - paste_x + xmin, p2[1] - paste_y + ymin] - p3 = [p3[0] - paste_x + xmin, p3[1] - paste_y + ymin] - p4 = [p4[0] - paste_x + xmin, p4[1] - paste_y + ymin] - - layout_res.append({ - 'category_id': 15, - 'poly': p1 + p2 + p3 + p4, - 'score': round(score, 2), - 'text': text, - }) + ocr_result_list = get_ocr_result_list(ocr_res, useful_list) + layout_res.extend(ocr_result_list) ocr_cost = round(time.time() - ocr_start, 2) logger.info(f"ocr time: {ocr_cost}") @@ -459,8 +199,6 @@ class CustomPEKModel: for res in table_res_list: new_image, _ = crop_img(res, pil_img) single_table_start_time = time.time() - # logger.info("------------------table recognition processing begins-----------------") - latex_code = None html_code = None if self.table_model_name == 
MODEL_NAME.STRUCT_EQTABLE: with torch.no_grad(): @@ -470,33 +208,21 @@ class CustomPEKModel: elif self.table_model_name == MODEL_NAME.TABLE_MASTER: html_code = self.table_model.img2html(new_image) elif self.table_model_name == MODEL_NAME.RAPID_TABLE: - ocr_result, _ = self.ocr_engine(np.asarray(new_image)) - html_code, table_cell_bboxes, elapse = self.table_model(np.asarray(new_image), ocr_result) - + html_code, table_cell_bboxes, elapse = self.table_model.predict(new_image) run_time = time.time() - single_table_start_time - # logger.info(f"------------table recognition processing ends within {run_time}s-----") if run_time > self.table_max_time: - logger.warning(f"------------table recognition processing exceeds max time {self.table_max_time}s----------") + logger.warning(f"table recognition processing exceeds max time {self.table_max_time}s") # 判断是否返回正常 - - if latex_code: - expected_ending = latex_code.strip().endswith('end{tabular}') or latex_code.strip().endswith('end{table}') - if expected_ending: - res["latex"] = latex_code - else: - logger.warning(f"table recognition processing fails, not found expected LaTeX table end") - elif html_code: + if html_code: expected_ending = html_code.strip().endswith('</html>') or html_code.strip().endswith('</table>') if expected_ending: res["html"] = html_code else: logger.warning(f"table recognition processing fails, not found expected HTML table end") else: - logger.warning(f"table recognition processing fails, not get latex or html return") + logger.warning(f"table recognition processing fails, not get html return") logger.info(f"table time: {round(time.time() - table_start, 2)}") logger.info(f"-----page total time: {round(time.time() - page_start, 2)}-----") return layout_res - - diff --git a/magic_pdf/model/pek_sub_modules/post_process.py b/magic_pdf/model/pek_sub_modules/post_process.py deleted file mode 100644 index aa050b61..00000000 --- a/magic_pdf/model/pek_sub_modules/post_process.py +++ /dev/null @@ -1,36 +0,0 @@ -import re - 
-def layout_rm_equation(layout_res): - rm_idxs = [] - for idx, ele in enumerate(layout_res['layout_dets']): - if ele['category_id'] == 10: - rm_idxs.append(idx) - - for idx in rm_idxs[::-1]: - del layout_res['layout_dets'][idx] - return layout_res - - -def get_croped_image(image_pil, bbox): - x_min, y_min, x_max, y_max = bbox - croped_img = image_pil.crop((x_min, y_min, x_max, y_max)) - return croped_img - - -def latex_rm_whitespace(s: str): - """Remove unnecessary whitespace from LaTeX code. - """ - text_reg = r'(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})' - letter = '[a-zA-Z]' - noletter = '[\W_^\d]' - names = [x[0].replace(' ', '') for x in re.findall(text_reg, s)] - s = re.sub(text_reg, lambda match: str(names.pop(0)), s) - news = s - while True: - s = news - news = re.sub(r'(?!\\ )(%s)\s+?(%s)' % (noletter, noletter), r'\1\2', s) - news = re.sub(r'(?!\\ )(%s)\s+?(%s)' % (noletter, letter), r'\1\2', news) - news = re.sub(r'(%s)\s+?(%s)' % (letter, noletter), r'\1\2', news) - if news == s: - break - return s \ No newline at end of file diff --git a/magic_pdf/model/pek_sub_modules/self_modify.py b/magic_pdf/model/pek_sub_modules/self_modify.py deleted file mode 100644 index c39c1a9c..00000000 --- a/magic_pdf/model/pek_sub_modules/self_modify.py +++ /dev/null @@ -1,388 +0,0 @@ -import time -import copy -import base64 -import cv2 -import numpy as np -from io import BytesIO -from PIL import Image - -from paddleocr import PaddleOCR -from paddleocr.ppocr.utils.logging import get_logger -from paddleocr.ppocr.utils.utility import check_and_read, alpha_to_color, binarize_img -from paddleocr.tools.infer.utility import draw_ocr_box_txt, get_rotate_crop_image, get_minarea_rect_crop - -from magic_pdf.libs.boxbase import __is_overlaps_y_exceeds_threshold -from magic_pdf.pre_proc.ocr_dict_merge import merge_spans_to_line - -logger = get_logger() - - -def img_decode(content: bytes): - np_arr = np.frombuffer(content, dtype=np.uint8) - return cv2.imdecode(np_arr, 
cv2.IMREAD_UNCHANGED) - - -def check_img(img): - if isinstance(img, bytes): - img = img_decode(img) - if isinstance(img, str): - image_file = img - img, flag_gif, flag_pdf = check_and_read(image_file) - if not flag_gif and not flag_pdf: - with open(image_file, 'rb') as f: - img_str = f.read() - img = img_decode(img_str) - if img is None: - try: - buf = BytesIO() - image = BytesIO(img_str) - im = Image.open(image) - rgb = im.convert('RGB') - rgb.save(buf, 'jpeg') - buf.seek(0) - image_bytes = buf.read() - data_base64 = str(base64.b64encode(image_bytes), - encoding="utf-8") - image_decode = base64.b64decode(data_base64) - img_array = np.frombuffer(image_decode, np.uint8) - img = cv2.imdecode(img_array, cv2.IMREAD_COLOR) - except: - logger.error("error in loading image:{}".format(image_file)) - return None - if img is None: - logger.error("error in loading image:{}".format(image_file)) - return None - if isinstance(img, np.ndarray) and len(img.shape) == 2: - img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) - - return img - - -def sorted_boxes(dt_boxes): - """ - Sort text boxes in order from top to bottom, left to right - args: - dt_boxes(array):detected text boxes with shape [4, 2] - return: - sorted boxes(array) with shape [4, 2] - """ - num_boxes = dt_boxes.shape[0] - sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0])) - _boxes = list(sorted_boxes) - - for i in range(num_boxes - 1): - for j in range(i, -1, -1): - if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 and \ - (_boxes[j + 1][0][0] < _boxes[j][0][0]): - tmp = _boxes[j] - _boxes[j] = _boxes[j + 1] - _boxes[j + 1] = tmp - else: - break - return _boxes - - -def bbox_to_points(bbox): - """ 将bbox格式转换为四个顶点的数组 """ - x0, y0, x1, y1 = bbox - return np.array([[x0, y0], [x1, y0], [x1, y1], [x0, y1]]).astype('float32') - - -def points_to_bbox(points): - """ 将四个顶点的数组转换为bbox格式 """ - x0, y0 = points[0] - x1, _ = points[1] - _, y1 = points[2] - return [x0, y0, x1, y1] - - -def merge_intervals(intervals): - # Sort 
the intervals based on the start value - intervals.sort(key=lambda x: x[0]) - - merged = [] - for interval in intervals: - # If the list of merged intervals is empty or if the current - # interval does not overlap with the previous, simply append it. - if not merged or merged[-1][1] < interval[0]: - merged.append(interval) - else: - # Otherwise, there is overlap, so we merge the current and previous intervals. - merged[-1][1] = max(merged[-1][1], interval[1]) - - return merged - - -def remove_intervals(original, masks): - # Merge all mask intervals - merged_masks = merge_intervals(masks) - - result = [] - original_start, original_end = original - - for mask in merged_masks: - mask_start, mask_end = mask - - # If the mask starts after the original range, ignore it - if mask_start > original_end: - continue - - # If the mask ends before the original range starts, ignore it - if mask_end < original_start: - continue - - # Remove the masked part from the original range - if original_start < mask_start: - result.append([original_start, mask_start - 1]) - - original_start = max(mask_end + 1, original_start) - - # Add the remaining part of the original range, if any - if original_start <= original_end: - result.append([original_start, original_end]) - - return result - - -def update_det_boxes(dt_boxes, mfd_res): - new_dt_boxes = [] - for text_box in dt_boxes: - text_bbox = points_to_bbox(text_box) - masks_list = [] - for mf_box in mfd_res: - mf_bbox = mf_box['bbox'] - if __is_overlaps_y_exceeds_threshold(text_bbox, mf_bbox): - masks_list.append([mf_bbox[0], mf_bbox[2]]) - text_x_range = [text_bbox[0], text_bbox[2]] - text_remove_mask_range = remove_intervals(text_x_range, masks_list) - temp_dt_box = [] - for text_remove_mask in text_remove_mask_range: - temp_dt_box.append(bbox_to_points([text_remove_mask[0], text_bbox[1], text_remove_mask[1], text_bbox[3]])) - if len(temp_dt_box) > 0: - new_dt_boxes.extend(temp_dt_box) - return new_dt_boxes - - -def 
merge_overlapping_spans(spans): - """ - Merges overlapping spans on the same line. - - :param spans: A list of span coordinates [(x1, y1, x2, y2), ...] - :return: A list of merged spans - """ - # Return an empty list if the input spans list is empty - if not spans: - return [] - - # Sort spans by their starting x-coordinate - spans.sort(key=lambda x: x[0]) - - # Initialize the list of merged spans - merged = [] - for span in spans: - # Unpack span coordinates - x1, y1, x2, y2 = span - # If the merged list is empty or there's no horizontal overlap, add the span directly - if not merged or merged[-1][2] < x1: - merged.append(span) - else: - # If there is horizontal overlap, merge the current span with the previous one - last_span = merged.pop() - # Update the merged span's top-left corner to the smaller (x1, y1) and bottom-right to the larger (x2, y2) - x1 = min(last_span[0], x1) - y1 = min(last_span[1], y1) - x2 = max(last_span[2], x2) - y2 = max(last_span[3], y2) - # Add the merged span back to the list - merged.append((x1, y1, x2, y2)) - - # Return the list of merged spans - return merged - - -def merge_det_boxes(dt_boxes): - """ - Merge detection boxes. - - This function takes a list of detected bounding boxes, each represented by four corner points. - The goal is to merge these bounding boxes into larger text regions. - - Parameters: - dt_boxes (list): A list containing multiple text detection boxes, where each box is defined by four corner points. - - Returns: - list: A list containing the merged text regions, where each region is represented by four corner points. 
- """ - # Convert the detection boxes into a dictionary format with bounding boxes and type - dt_boxes_dict_list = [] - for text_box in dt_boxes: - text_bbox = points_to_bbox(text_box) - text_box_dict = { - 'bbox': text_bbox, - 'type': 'text', - } - dt_boxes_dict_list.append(text_box_dict) - - # Merge adjacent text regions into lines - lines = merge_spans_to_line(dt_boxes_dict_list) - - # Initialize a new list for storing the merged text regions - new_dt_boxes = [] - for line in lines: - line_bbox_list = [] - for span in line: - line_bbox_list.append(span['bbox']) - - # Merge overlapping text regions within the same line - merged_spans = merge_overlapping_spans(line_bbox_list) - - # Convert the merged text regions back to point format and add them to the new detection box list - for span in merged_spans: - new_dt_boxes.append(bbox_to_points(span)) - - return new_dt_boxes - - -class ModifiedPaddleOCR(PaddleOCR): - def ocr(self, img, det=True, rec=True, cls=True, bin=False, inv=False, mfd_res=None, alpha_color=(255, 255, 255)): - """ - OCR with PaddleOCR - args: - img: img for OCR, support ndarray, img_path and list or ndarray - det: use text detection or not. If False, only rec will be exec. Default is True - rec: use text recognition or not. If False, only det will be exec. Default is True - cls: use angle classifier or not. Default is True. If True, the text with rotation of 180 degrees can be recognized. If no text is rotated by 180 degrees, use cls=False to get better performance. Text with rotation of 90 or 270 degrees can be recognized even if cls=False. - bin: binarize image to black and white. Default is False. - inv: invert image colors. Default is False. - alpha_color: set RGB color Tuple for transparent parts replacement. Default is pure white. 
- """ - assert isinstance(img, (np.ndarray, list, str, bytes)) - if isinstance(img, list) and det == True: - logger.error('When input a list of images, det must be false') - exit(0) - if cls == True and self.use_angle_cls == False: - pass - # logger.warning( - # 'Since the angle classifier is not initialized, it will not be used during the forward process' - # ) - - img = check_img(img) - # for infer pdf file - if isinstance(img, list): - if self.page_num > len(img) or self.page_num == 0: - self.page_num = len(img) - imgs = img[:self.page_num] - else: - imgs = [img] - - def preprocess_image(_image): - _image = alpha_to_color(_image, alpha_color) - if inv: - _image = cv2.bitwise_not(_image) - if bin: - _image = binarize_img(_image) - return _image - - if det and rec: - ocr_res = [] - for idx, img in enumerate(imgs): - img = preprocess_image(img) - dt_boxes, rec_res, _ = self.__call__(img, cls, mfd_res=mfd_res) - if not dt_boxes and not rec_res: - ocr_res.append(None) - continue - tmp_res = [[box.tolist(), res] - for box, res in zip(dt_boxes, rec_res)] - ocr_res.append(tmp_res) - return ocr_res - elif det and not rec: - ocr_res = [] - for idx, img in enumerate(imgs): - img = preprocess_image(img) - dt_boxes, elapse = self.text_detector(img) - if not dt_boxes: - ocr_res.append(None) - continue - tmp_res = [box.tolist() for box in dt_boxes] - ocr_res.append(tmp_res) - return ocr_res - else: - ocr_res = [] - cls_res = [] - for idx, img in enumerate(imgs): - if not isinstance(img, list): - img = preprocess_image(img) - img = [img] - if self.use_angle_cls and cls: - img, cls_res_tmp, elapse = self.text_classifier(img) - if not rec: - cls_res.append(cls_res_tmp) - rec_res, elapse = self.text_recognizer(img) - ocr_res.append(rec_res) - if not rec: - return cls_res - return ocr_res - - def __call__(self, img, cls=True, mfd_res=None): - time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0} - - if img is None: - logger.debug("no valid image provided") - return None, None, 
time_dict - - start = time.time() - ori_im = img.copy() - dt_boxes, elapse = self.text_detector(img) - time_dict['det'] = elapse - - if dt_boxes is None: - logger.debug("no dt_boxes found, elapsed : {}".format(elapse)) - end = time.time() - time_dict['all'] = end - start - return None, None, time_dict - else: - logger.debug("dt_boxes num : {}, elapsed : {}".format( - len(dt_boxes), elapse)) - img_crop_list = [] - - dt_boxes = sorted_boxes(dt_boxes) - - dt_boxes = merge_det_boxes(dt_boxes) - - if mfd_res: - bef = time.time() - dt_boxes = update_det_boxes(dt_boxes, mfd_res) - aft = time.time() - logger.debug("split text box by formula, new dt_boxes num : {}, elapsed : {}".format( - len(dt_boxes), aft - bef)) - - for bno in range(len(dt_boxes)): - tmp_box = copy.deepcopy(dt_boxes[bno]) - if self.args.det_box_type == "quad": - img_crop = get_rotate_crop_image(ori_im, tmp_box) - else: - img_crop = get_minarea_rect_crop(ori_im, tmp_box) - img_crop_list.append(img_crop) - if self.use_angle_cls and cls: - img_crop_list, angle_list, elapse = self.text_classifier( - img_crop_list) - time_dict['cls'] = elapse - logger.debug("cls num : {}, elapsed : {}".format( - len(img_crop_list), elapse)) - - rec_res, elapse = self.text_recognizer(img_crop_list) - time_dict['rec'] = elapse - logger.debug("rec_res num : {}, elapsed : {}".format( - len(rec_res), elapse)) - if self.args.save_crop_res: - self.draw_crop_rec_res(self.args.crop_res_save_dir, img_crop_list, - rec_res) - filter_boxes, filter_rec_res = [], [] - for box, rec_result in zip(dt_boxes, rec_res): - text, score = rec_result - if score >= self.drop_score: - filter_boxes.append(box) - filter_rec_res.append(rec_result) - end = time.time() - time_dict['all'] = end - start - return filter_boxes, filter_rec_res, time_dict \ No newline at end of file diff --git a/magic_pdf/model/pek_sub_modules/__init__.py b/magic_pdf/model/sub_modules/__init__.py similarity index 100% rename from magic_pdf/model/pek_sub_modules/__init__.py rename 
to magic_pdf/model/sub_modules/__init__.py diff --git a/magic_pdf/model/pek_sub_modules/layoutlmv3/__init__.py b/magic_pdf/model/sub_modules/layout/__init__.py similarity index 100% rename from magic_pdf/model/pek_sub_modules/layoutlmv3/__init__.py rename to magic_pdf/model/sub_modules/layout/__init__.py diff --git a/magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py b/magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py new file mode 100644 index 00000000..ab38bf07 --- /dev/null +++ b/magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py @@ -0,0 +1,21 @@ +from doclayout_yolo import YOLOv10 + + +class DocLayoutYOLOModel(object): + def __init__(self, weight, device): + self.model = YOLOv10(weight) + self.device = device + + def predict(self, image): + layout_res = [] + doclayout_yolo_res = self.model.predict(image, imgsz=1024, conf=0.25, iou=0.45, verbose=True, device=self.device)[0] + for xyxy, conf, cla in zip(doclayout_yolo_res.boxes.xyxy.cpu(), doclayout_yolo_res.boxes.conf.cpu(), + doclayout_yolo_res.boxes.cls.cpu()): + xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy] + new_item = { + 'category_id': int(cla.item()), + 'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax], + 'score': round(float(conf.item()), 3), + } + layout_res.append(new_item) + return layout_res \ No newline at end of file diff --git a/magic_pdf/model/pek_sub_modules/structeqtable/__init__.py b/magic_pdf/model/sub_modules/layout/doclayout_yolo/__init__.py similarity index 100% rename from magic_pdf/model/pek_sub_modules/structeqtable/__init__.py rename to magic_pdf/model/sub_modules/layout/doclayout_yolo/__init__.py diff --git a/magic_pdf/model/v3/__init__.py b/magic_pdf/model/sub_modules/layout/layoutlmv3/__init__.py similarity index 100% rename from magic_pdf/model/v3/__init__.py rename to magic_pdf/model/sub_modules/layout/layoutlmv3/__init__.py diff --git a/magic_pdf/model/pek_sub_modules/layoutlmv3/backbone.py 
b/magic_pdf/model/sub_modules/layout/layoutlmv3/backbone.py similarity index 100% rename from magic_pdf/model/pek_sub_modules/layoutlmv3/backbone.py rename to magic_pdf/model/sub_modules/layout/layoutlmv3/backbone.py diff --git a/magic_pdf/model/pek_sub_modules/layoutlmv3/beit.py b/magic_pdf/model/sub_modules/layout/layoutlmv3/beit.py similarity index 100% rename from magic_pdf/model/pek_sub_modules/layoutlmv3/beit.py rename to magic_pdf/model/sub_modules/layout/layoutlmv3/beit.py diff --git a/magic_pdf/model/pek_sub_modules/layoutlmv3/deit.py b/magic_pdf/model/sub_modules/layout/layoutlmv3/deit.py similarity index 100% rename from magic_pdf/model/pek_sub_modules/layoutlmv3/deit.py rename to magic_pdf/model/sub_modules/layout/layoutlmv3/deit.py diff --git a/magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/__init__.py b/magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/__init__.py similarity index 100% rename from magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/__init__.py rename to magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/__init__.py diff --git a/magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/__init__.py b/magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/data/__init__.py similarity index 100% rename from magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/__init__.py rename to magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/data/__init__.py diff --git a/magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/cord.py b/magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/data/cord.py similarity index 100% rename from magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/cord.py rename to magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/data/cord.py diff --git a/magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/data_collator.py b/magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/data/data_collator.py similarity index 100% rename from 
magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/data_collator.py rename to magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/data/data_collator.py diff --git a/magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/funsd.py b/magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/data/funsd.py similarity index 100% rename from magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/funsd.py rename to magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/data/funsd.py diff --git a/magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/image_utils.py b/magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/data/image_utils.py similarity index 100% rename from magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/image_utils.py rename to magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/data/image_utils.py diff --git a/magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/xfund.py b/magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/data/xfund.py similarity index 100% rename from magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/xfund.py rename to magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/data/xfund.py diff --git a/magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/__init__.py b/magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/models/__init__.py similarity index 100% rename from magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/__init__.py rename to magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/models/__init__.py diff --git a/magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/__init__.py b/magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/models/layoutlmv3/__init__.py similarity index 100% rename from magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/__init__.py rename to magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/models/layoutlmv3/__init__.py diff --git 
a/magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/configuration_layoutlmv3.py b/magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/models/layoutlmv3/configuration_layoutlmv3.py similarity index 100% rename from magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/configuration_layoutlmv3.py rename to magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/models/layoutlmv3/configuration_layoutlmv3.py diff --git a/magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/modeling_layoutlmv3.py b/magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/models/layoutlmv3/modeling_layoutlmv3.py similarity index 100% rename from magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/modeling_layoutlmv3.py rename to magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/models/layoutlmv3/modeling_layoutlmv3.py diff --git a/magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3.py b/magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3.py similarity index 100% rename from magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3.py rename to magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3.py diff --git a/magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3_fast.py b/magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3_fast.py similarity index 100% rename from magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3_fast.py rename to magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3_fast.py diff --git a/magic_pdf/model/pek_sub_modules/layoutlmv3/model_init.py b/magic_pdf/model/sub_modules/layout/layoutlmv3/model_init.py similarity index 100% 
rename from magic_pdf/model/pek_sub_modules/layoutlmv3/model_init.py rename to magic_pdf/model/sub_modules/layout/layoutlmv3/model_init.py diff --git a/magic_pdf/model/pek_sub_modules/layoutlmv3/rcnn_vl.py b/magic_pdf/model/sub_modules/layout/layoutlmv3/rcnn_vl.py similarity index 100% rename from magic_pdf/model/pek_sub_modules/layoutlmv3/rcnn_vl.py rename to magic_pdf/model/sub_modules/layout/layoutlmv3/rcnn_vl.py diff --git a/magic_pdf/model/pek_sub_modules/layoutlmv3/visualizer.py b/magic_pdf/model/sub_modules/layout/layoutlmv3/visualizer.py similarity index 100% rename from magic_pdf/model/pek_sub_modules/layoutlmv3/visualizer.py rename to magic_pdf/model/sub_modules/layout/layoutlmv3/visualizer.py diff --git a/magic_pdf/model/sub_modules/mfd/__init__.py b/magic_pdf/model/sub_modules/mfd/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/magic_pdf/model/sub_modules/mfd/yolov8/YOLOv8.py b/magic_pdf/model/sub_modules/mfd/yolov8/YOLOv8.py new file mode 100644 index 00000000..594df265 --- /dev/null +++ b/magic_pdf/model/sub_modules/mfd/yolov8/YOLOv8.py @@ -0,0 +1,12 @@ +from ultralytics import YOLO + + +class YOLOv8MFDModel(object): + def __init__(self, weight, device='cpu'): + self.mfd_model = YOLO(weight) + self.device = device + + def predict(self, image): + mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True, device=self.device)[0] + return mfd_res + diff --git a/magic_pdf/model/sub_modules/mfd/yolov8/__init__.py b/magic_pdf/model/sub_modules/mfd/yolov8/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/magic_pdf/model/sub_modules/mfr/__init__.py b/magic_pdf/model/sub_modules/mfr/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py b/magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py new file mode 100644 index 00000000..30b21ef8 --- /dev/null +++ b/magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py @@ -0,0 
+1,98 @@ +import os +import argparse +import re + +from PIL import Image +import torch +from torch.utils.data import Dataset, DataLoader +from torchvision import transforms +from unimernet.common.config import Config +import unimernet.tasks as tasks +from unimernet.processors import load_processor + + +class MathDataset(Dataset): + def __init__(self, image_paths, transform=None): + self.image_paths = image_paths + self.transform = transform + + def __len__(self): + return len(self.image_paths) + + def __getitem__(self, idx): + # if not pil image, then convert to pil image + if isinstance(self.image_paths[idx], str): + raw_image = Image.open(self.image_paths[idx]) + else: + raw_image = self.image_paths[idx] + if self.transform: + image = self.transform(raw_image) + return image + + +def latex_rm_whitespace(s: str): + """Remove unnecessary whitespace from LaTeX code. + """ + text_reg = r'(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})' + letter = '[a-zA-Z]' + noletter = '[\W_^\d]' + names = [x[0].replace(' ', '') for x in re.findall(text_reg, s)] + s = re.sub(text_reg, lambda match: str(names.pop(0)), s) + news = s + while True: + s = news + news = re.sub(r'(?!\\ )(%s)\s+?(%s)' % (noletter, noletter), r'\1\2', s) + news = re.sub(r'(?!\\ )(%s)\s+?(%s)' % (noletter, letter), r'\1\2', news) + news = re.sub(r'(%s)\s+?(%s)' % (letter, noletter), r'\1\2', news) + if news == s: + break + return s + + +class UnimernetModel(object): + def __init__(self, weight_dir, cfg_path, _device_='cpu'): + + args = argparse.Namespace(cfg_path=cfg_path, options=None) + cfg = Config(args) + cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.pth") + cfg.config.model.model_config.model_name = weight_dir + cfg.config.model.tokenizer_config.path = weight_dir + task = tasks.setup_task(cfg) + self.model = task.build_model(cfg) + self.device = _device_ + self.model.to(_device_) + self.model.eval() + vis_processor = load_processor('formula_image_eval', 
cfg.config.datasets.formula_rec_eval.vis_processor.eval) + self.mfr_transform = transforms.Compose([vis_processor, ]) + + def predict(self, mfd_res, image): + + formula_list = [] + mf_image_list = [] + for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()): + xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy] + new_item = { + 'category_id': 13 + int(cla.item()), + 'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax], + 'score': round(float(conf.item()), 2), + 'latex': '', + } + formula_list.append(new_item) + pil_img = Image.fromarray(image) + bbox_img = pil_img.crop((xmin, ymin, xmax, ymax)) + mf_image_list.append(bbox_img) + + dataset = MathDataset(mf_image_list, transform=self.mfr_transform) + dataloader = DataLoader(dataset, batch_size=64, num_workers=0) + mfr_res = [] + for mf_img in dataloader: + mf_img = mf_img.to(self.device) + with torch.no_grad(): + output = self.model.generate({'image': mf_img}) + mfr_res.extend(output['pred_str']) + for res, latex in zip(formula_list, mfr_res): + res['latex'] = latex_rm_whitespace(latex) + return formula_list + + + diff --git a/magic_pdf/model/sub_modules/mfr/unimernet/__init__.py b/magic_pdf/model/sub_modules/mfr/unimernet/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/magic_pdf/model/sub_modules/model_init.py b/magic_pdf/model/sub_modules/model_init.py new file mode 100644 index 00000000..7c758518 --- /dev/null +++ b/magic_pdf/model/sub_modules/model_init.py @@ -0,0 +1,144 @@ +from loguru import logger + +from magic_pdf.libs.Constants import MODEL_NAME +from magic_pdf.model.model_list import AtomicModel +from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import DocLayoutYOLOModel +from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import Layoutlmv3_Predictor +from magic_pdf.model.sub_modules.mfd.yolov8.YOLOv8 import YOLOv8MFDModel + +from magic_pdf.model.sub_modules.mfr.unimernet.Unimernet import 
UnimernetModel +from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import ModifiedPaddleOCR +# from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR +from magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable import StructTableModel +from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import TableMasterPaddleModel +from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel + + +def table_model_init(table_model_type, model_path, max_time, _device_='cpu'): + if table_model_type == MODEL_NAME.STRUCT_EQTABLE: + table_model = StructTableModel(model_path, max_new_tokens=2048, max_time=max_time) + elif table_model_type == MODEL_NAME.TABLE_MASTER: + config = { + "model_dir": model_path, + "device": _device_ + } + table_model = TableMasterPaddleModel(config) + elif table_model_type == MODEL_NAME.RAPID_TABLE: + table_model = RapidTableModel() + else: + logger.error("table model type not allow") + exit(1) + + return table_model + + +def mfd_model_init(weight, device='cpu'): + mfd_model = YOLOv8MFDModel(weight, device) + return mfd_model + + +def mfr_model_init(weight_dir, cfg_path, device='cpu'): + mfr_model = UnimernetModel(weight_dir, cfg_path, device) + return mfr_model + + +def layout_model_init(weight, config_file, device): + model = Layoutlmv3_Predictor(weight, config_file, device) + return model + + +def doclayout_yolo_model_init(weight, device='cpu'): + model = DocLayoutYOLOModel(weight, device) + return model + + +def ocr_model_init(show_log: bool = False, + det_db_box_thresh=0.3, + lang=None, + use_dilation=True, + det_db_unclip_ratio=1.8, + ): + if lang is not None: + model = ModifiedPaddleOCR( + show_log=show_log, + det_db_box_thresh=det_db_box_thresh, + lang=lang, + use_dilation=use_dilation, + det_db_unclip_ratio=det_db_unclip_ratio, + ) + else: + model = ModifiedPaddleOCR( + show_log=show_log, + det_db_box_thresh=det_db_box_thresh, + use_dilation=use_dilation, + 
det_db_unclip_ratio=det_db_unclip_ratio, + # use_angle_cls=True, + ) + return model + + +class AtomModelSingleton: + _instance = None + _models = {} + + def __new__(cls, *args, **kwargs): + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def get_atom_model(self, atom_model_name: str, **kwargs): + lang = kwargs.get("lang", None) + layout_model_name = kwargs.get("layout_model_name", None) + key = (atom_model_name, layout_model_name, lang) + if key not in self._models: + self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs) + return self._models[key] + + +def atom_model_init(model_name: str, **kwargs): + atom_model = None + if model_name == AtomicModel.Layout: + if kwargs.get("layout_model_name") == MODEL_NAME.LAYOUTLMv3: + atom_model = layout_model_init( + kwargs.get("layout_weights"), + kwargs.get("layout_config_file"), + kwargs.get("device") + ) + elif kwargs.get("layout_model_name") == MODEL_NAME.DocLayout_YOLO: + atom_model = doclayout_yolo_model_init( + kwargs.get("doclayout_yolo_weights"), + kwargs.get("device") + ) + elif model_name == AtomicModel.MFD: + atom_model = mfd_model_init( + kwargs.get("mfd_weights"), + kwargs.get("device") + ) + elif model_name == AtomicModel.MFR: + atom_model = mfr_model_init( + kwargs.get("mfr_weight_dir"), + kwargs.get("mfr_cfg_path"), + kwargs.get("device") + ) + elif model_name == AtomicModel.OCR: + atom_model = ocr_model_init( + kwargs.get("ocr_show_log"), + kwargs.get("det_db_box_thresh"), + kwargs.get("lang") + ) + elif model_name == AtomicModel.Table: + atom_model = table_model_init( + kwargs.get("table_model_name"), + kwargs.get("table_model_path"), + kwargs.get("table_max_time"), + kwargs.get("device") + ) + else: + logger.error("model name not allow") + exit(1) + + if atom_model is None: + logger.error("model init failed") + exit(1) + else: + return atom_model diff --git a/magic_pdf/model/sub_modules/model_utils.py 
b/magic_pdf/model/sub_modules/model_utils.py new file mode 100644 index 00000000..55114679 --- /dev/null +++ b/magic_pdf/model/sub_modules/model_utils.py @@ -0,0 +1,51 @@ +import time + +import torch +from PIL import Image +from loguru import logger + +from magic_pdf.libs.clean_memory import clean_memory + + +def crop_img(input_res, input_pil_img, crop_paste_x=0, crop_paste_y=0): + crop_xmin, crop_ymin = int(input_res['poly'][0]), int(input_res['poly'][1]) + crop_xmax, crop_ymax = int(input_res['poly'][4]), int(input_res['poly'][5]) + # Create a white background with an additional width and height of 50 + crop_new_width = crop_xmax - crop_xmin + crop_paste_x * 2 + crop_new_height = crop_ymax - crop_ymin + crop_paste_y * 2 + return_image = Image.new('RGB', (crop_new_width, crop_new_height), 'white') + + # Crop image + crop_box = (crop_xmin, crop_ymin, crop_xmax, crop_ymax) + cropped_img = input_pil_img.crop(crop_box) + return_image.paste(cropped_img, (crop_paste_x, crop_paste_y)) + return_list = [crop_paste_x, crop_paste_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width, crop_new_height] + return return_image, return_list + + +# Select regions for OCR / formula regions / table regions +def get_res_list_from_layout_res(layout_res): + ocr_res_list = [] + table_res_list = [] + single_page_mfdetrec_res = [] + for res in layout_res: + if int(res['category_id']) in [13, 14]: + single_page_mfdetrec_res.append({ + "bbox": [int(res['poly'][0]), int(res['poly'][1]), + int(res['poly'][4]), int(res['poly'][5])], + }) + elif int(res['category_id']) in [0, 1, 2, 4, 6, 7]: + ocr_res_list.append(res) + elif int(res['category_id']) in [5]: + table_res_list.append(res) + return ocr_res_list, table_res_list, single_page_mfdetrec_res + + +def clean_vram(device, vram_threshold=8): + if torch.cuda.is_available() and device != 'cpu': + total_memory = torch.cuda.get_device_properties(device).total_memory / (1024 ** 3) # 将字节转换为 GB + if total_memory <= vram_threshold: + gc_start 
= time.time() + clean_memory() + gc_time = round(time.time() - gc_start, 2) + logger.info(f"gc time: {gc_time}") \ No newline at end of file diff --git a/magic_pdf/model/sub_modules/ocr/__init__.py b/magic_pdf/model/sub_modules/ocr/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/magic_pdf/model/sub_modules/ocr/paddleocr/__init__.py b/magic_pdf/model/sub_modules/ocr/paddleocr/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/magic_pdf/model/sub_modules/ocr/paddleocr/ocr_utils.py b/magic_pdf/model/sub_modules/ocr/paddleocr/ocr_utils.py new file mode 100644 index 00000000..f81ebbcf --- /dev/null +++ b/magic_pdf/model/sub_modules/ocr/paddleocr/ocr_utils.py @@ -0,0 +1,259 @@ +import math + +import numpy as np +from loguru import logger + +from magic_pdf.libs.boxbase import __is_overlaps_y_exceeds_threshold +from magic_pdf.pre_proc.ocr_dict_merge import merge_spans_to_line + + +def bbox_to_points(bbox): + """ 将bbox格式转换为四个顶点的数组 """ + x0, y0, x1, y1 = bbox + return np.array([[x0, y0], [x1, y0], [x1, y1], [x0, y1]]).astype('float32') + + +def points_to_bbox(points): + """ 将四个顶点的数组转换为bbox格式 """ + x0, y0 = points[0] + x1, _ = points[1] + _, y1 = points[2] + return [x0, y0, x1, y1] + + +def merge_intervals(intervals): + # Sort the intervals based on the start value + intervals.sort(key=lambda x: x[0]) + + merged = [] + for interval in intervals: + # If the list of merged intervals is empty or if the current + # interval does not overlap with the previous, simply append it. + if not merged or merged[-1][1] < interval[0]: + merged.append(interval) + else: + # Otherwise, there is overlap, so we merge the current and previous intervals. 
+ merged[-1][1] = max(merged[-1][1], interval[1]) + + return merged + + +def remove_intervals(original, masks): + # Merge all mask intervals + merged_masks = merge_intervals(masks) + + result = [] + original_start, original_end = original + + for mask in merged_masks: + mask_start, mask_end = mask + + # If the mask starts after the original range, ignore it + if mask_start > original_end: + continue + + # If the mask ends before the original range starts, ignore it + if mask_end < original_start: + continue + + # Remove the masked part from the original range + if original_start < mask_start: + result.append([original_start, mask_start - 1]) + + original_start = max(mask_end + 1, original_start) + + # Add the remaining part of the original range, if any + if original_start <= original_end: + result.append([original_start, original_end]) + + return result + + +def update_det_boxes(dt_boxes, mfd_res): + new_dt_boxes = [] + for text_box in dt_boxes: + text_bbox = points_to_bbox(text_box) + masks_list = [] + for mf_box in mfd_res: + mf_bbox = mf_box['bbox'] + if __is_overlaps_y_exceeds_threshold(text_bbox, mf_bbox): + masks_list.append([mf_bbox[0], mf_bbox[2]]) + text_x_range = [text_bbox[0], text_bbox[2]] + text_remove_mask_range = remove_intervals(text_x_range, masks_list) + temp_dt_box = [] + for text_remove_mask in text_remove_mask_range: + temp_dt_box.append(bbox_to_points([text_remove_mask[0], text_bbox[1], text_remove_mask[1], text_bbox[3]])) + if len(temp_dt_box) > 0: + new_dt_boxes.extend(temp_dt_box) + return new_dt_boxes + + +def merge_overlapping_spans(spans): + """ + Merges overlapping spans on the same line. + + :param spans: A list of span coordinates [(x1, y1, x2, y2), ...] 
+ :return: A list of merged spans + """ + # Return an empty list if the input spans list is empty + if not spans: + return [] + + # Sort spans by their starting x-coordinate + spans.sort(key=lambda x: x[0]) + + # Initialize the list of merged spans + merged = [] + for span in spans: + # Unpack span coordinates + x1, y1, x2, y2 = span + # If the merged list is empty or there's no horizontal overlap, add the span directly + if not merged or merged[-1][2] < x1: + merged.append(span) + else: + # If there is horizontal overlap, merge the current span with the previous one + last_span = merged.pop() + # Update the merged span's top-left corner to the smaller (x1, y1) and bottom-right to the larger (x2, y2) + x1 = min(last_span[0], x1) + y1 = min(last_span[1], y1) + x2 = max(last_span[2], x2) + y2 = max(last_span[3], y2) + # Add the merged span back to the list + merged.append((x1, y1, x2, y2)) + + # Return the list of merged spans + return merged + + +def merge_det_boxes(dt_boxes): + """ + Merge detection boxes. + + This function takes a list of detected bounding boxes, each represented by four corner points. + The goal is to merge these bounding boxes into larger text regions. + + Parameters: + dt_boxes (list): A list containing multiple text detection boxes, where each box is defined by four corner points. + + Returns: + list: A list containing the merged text regions, where each region is represented by four corner points. 
+ """ + # Convert the detection boxes into a dictionary format with bounding boxes and type + dt_boxes_dict_list = [] + angle_boxes_list = [] + for text_box in dt_boxes: + text_bbox = points_to_bbox(text_box) + if text_bbox[2] <= text_bbox[0] or text_bbox[3] <= text_bbox[1]: + angle_boxes_list.append(text_box) + continue + text_box_dict = { + 'bbox': text_bbox, + 'type': 'text', + } + dt_boxes_dict_list.append(text_box_dict) + + # Merge adjacent text regions into lines + lines = merge_spans_to_line(dt_boxes_dict_list) + + # Initialize a new list for storing the merged text regions + new_dt_boxes = [] + for line in lines: + line_bbox_list = [] + for span in line: + line_bbox_list.append(span['bbox']) + + # Merge overlapping text regions within the same line + merged_spans = merge_overlapping_spans(line_bbox_list) + + # Convert the merged text regions back to point format and add them to the new detection box list + for span in merged_spans: + new_dt_boxes.append(bbox_to_points(span)) + + new_dt_boxes.extend(angle_boxes_list) + + return new_dt_boxes + + +def get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list): + paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list + # Adjust the coordinates of the formula area + adjusted_mfdetrec_res = [] + for mf_res in single_page_mfdetrec_res: + mf_xmin, mf_ymin, mf_xmax, mf_ymax = mf_res["bbox"] + # Adjust the coordinates of the formula area to the coordinates relative to the cropping area + x0 = mf_xmin - xmin + paste_x + y0 = mf_ymin - ymin + paste_y + x1 = mf_xmax - xmin + paste_x + y1 = mf_ymax - ymin + paste_y + # Filter formula blocks outside the graph + if any([x1 < 0, y1 < 0]) or any([x0 > new_width, y0 > new_height]): + continue + else: + adjusted_mfdetrec_res.append({ + "bbox": [x0, y0, x1, y1], + }) + return adjusted_mfdetrec_res + + +def get_ocr_result_list(ocr_res, useful_list): + paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list + ocr_result_list = [] 
+ for box_ocr_res in ocr_res: + + p1, p2, p3, p4 = box_ocr_res[0] + text, score = box_ocr_res[1] + average_angle_degrees = calculate_angle_degrees(box_ocr_res[0]) + if average_angle_degrees > 0.5: + # logger.info(f"average_angle_degrees: {average_angle_degrees}, text: {text}") + # 与x轴的夹角超过0.5度,对边界做一下矫正 + # 计算几何中心 + x_center = sum(point[0] for point in box_ocr_res[0]) / 4 + y_center = sum(point[1] for point in box_ocr_res[0]) / 4 + new_height = ((p4[1] - p1[1]) + (p3[1] - p2[1])) / 2 + new_width = p3[0] - p1[0] + p1 = [x_center - new_width / 2, y_center - new_height / 2] + p2 = [x_center + new_width / 2, y_center - new_height / 2] + p3 = [x_center + new_width / 2, y_center + new_height / 2] + p4 = [x_center - new_width / 2, y_center + new_height / 2] + + # Convert the coordinates back to the original coordinate system + p1 = [p1[0] - paste_x + xmin, p1[1] - paste_y + ymin] + p2 = [p2[0] - paste_x + xmin, p2[1] - paste_y + ymin] + p3 = [p3[0] - paste_x + xmin, p3[1] - paste_y + ymin] + p4 = [p4[0] - paste_x + xmin, p4[1] - paste_y + ymin] + + ocr_result_list.append({ + 'category_id': 15, + 'poly': p1 + p2 + p3 + p4, + 'score': float(round(score, 2)), + 'text': text, + }) + + return ocr_result_list + + +def calculate_angle_degrees(poly): + # 定义对角线的顶点 + diagonal1 = (poly[0], poly[2]) + diagonal2 = (poly[1], poly[3]) + + # 计算对角线的斜率 + def slope(p1, p2): + return (p2[1] - p1[1]) / (p2[0] - p1[0]) if p2[0] != p1[0] else float('inf') + + slope1 = slope(diagonal1[0], diagonal1[1]) + slope2 = slope(diagonal2[0], diagonal2[1]) + + # 计算对角线与x轴的夹角(以弧度为单位) + angle1_radians = math.atan(slope1) + angle2_radians = math.atan(slope2) + + # 将弧度转换为角度 + angle1_degrees = math.degrees(angle1_radians) + angle2_degrees = math.degrees(angle2_radians) + + # 取两条对角线与x轴夹角的平均值 + average_angle_degrees = abs((angle1_degrees + angle2_degrees) / 2) + # logger.info(f"average_angle_degrees: {average_angle_degrees}") + return average_angle_degrees + diff --git 
a/magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_273_mod.py b/magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_273_mod.py new file mode 100644 index 00000000..226846ba --- /dev/null +++ b/magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_273_mod.py @@ -0,0 +1,168 @@ +import copy +import time + +import cv2 +import numpy as np +from paddleocr import PaddleOCR +from paddleocr.paddleocr import check_img, logger +from paddleocr.ppocr.utils.utility import alpha_to_color, binarize_img +from paddleocr.tools.infer.predict_system import sorted_boxes +from paddleocr.tools.infer.utility import get_rotate_crop_image, get_minarea_rect_crop + +from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import update_det_boxes, merge_det_boxes + + +class ModifiedPaddleOCR(PaddleOCR): + def ocr(self, + img, + det=True, + rec=True, + cls=True, + bin=False, + inv=False, + alpha_color=(255, 255, 255), + mfd_res=None, + ): + """ + OCR with PaddleOCR + args: + img: img for OCR, support ndarray, img_path and list or ndarray + det: use text detection or not. If False, only rec will be exec. Default is True + rec: use text recognition or not. If False, only det will be exec. Default is True + cls: use angle classifier or not. Default is True. If True, the text with rotation of 180 degrees can be recognized. If no text is rotated by 180 degrees, use cls=False to get better performance. Text with rotation of 90 or 270 degrees can be recognized even if cls=False. + bin: binarize image to black and white. Default is False. + inv: invert image colors. Default is False. + alpha_color: set RGB color Tuple for transparent parts replacement. Default is pure white. 
+ """ + assert isinstance(img, (np.ndarray, list, str, bytes)) + if isinstance(img, list) and det == True: + logger.error('When input a list of images, det must be false') + exit(0) + if cls == True and self.use_angle_cls == False: + pass + # logger.warning( + # 'Since the angle classifier is not initialized, it will not be used during the forward process' + # ) + + img = check_img(img) + # for infer pdf file + if isinstance(img, list): + if self.page_num > len(img) or self.page_num == 0: + self.page_num = len(img) + imgs = img[:self.page_num] + else: + imgs = [img] + + def preprocess_image(_image): + _image = alpha_to_color(_image, alpha_color) + if inv: + _image = cv2.bitwise_not(_image) + if bin: + _image = binarize_img(_image) + return _image + + if det and rec: + ocr_res = [] + for idx, img in enumerate(imgs): + img = preprocess_image(img) + dt_boxes, rec_res, _ = self.__call__(img, cls, mfd_res=mfd_res) + if not dt_boxes and not rec_res: + ocr_res.append(None) + continue + tmp_res = [[box.tolist(), res] + for box, res in zip(dt_boxes, rec_res)] + ocr_res.append(tmp_res) + return ocr_res + elif det and not rec: + ocr_res = [] + for idx, img in enumerate(imgs): + img = preprocess_image(img) + dt_boxes, elapse = self.text_detector(img) + if not dt_boxes: + ocr_res.append(None) + continue + tmp_res = [box.tolist() for box in dt_boxes] + ocr_res.append(tmp_res) + return ocr_res + else: + ocr_res = [] + cls_res = [] + for idx, img in enumerate(imgs): + if not isinstance(img, list): + img = preprocess_image(img) + img = [img] + if self.use_angle_cls and cls: + img, cls_res_tmp, elapse = self.text_classifier(img) + if not rec: + cls_res.append(cls_res_tmp) + rec_res, elapse = self.text_recognizer(img) + ocr_res.append(rec_res) + if not rec: + return cls_res + return ocr_res + + def __call__(self, img, cls=True, mfd_res=None): + time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0} + + if img is None: + logger.debug("no valid image provided") + return None, None, 
time_dict + + start = time.time() + ori_im = img.copy() + dt_boxes, elapse = self.text_detector(img) + time_dict['det'] = elapse + + if dt_boxes is None: + logger.debug("no dt_boxes found, elapsed : {}".format(elapse)) + end = time.time() + time_dict['all'] = end - start + return None, None, time_dict + else: + logger.debug("dt_boxes num : {}, elapsed : {}".format( + len(dt_boxes), elapse)) + img_crop_list = [] + + dt_boxes = sorted_boxes(dt_boxes) + + # @todo 目前是在bbox层merge,对倾斜文本行的兼容性不佳,需要修改成支持poly的merge + # dt_boxes = merge_det_boxes(dt_boxes) + + + if mfd_res: + bef = time.time() + dt_boxes = update_det_boxes(dt_boxes, mfd_res) + aft = time.time() + logger.debug("split text box by formula, new dt_boxes num : {}, elapsed : {}".format( + len(dt_boxes), aft - bef)) + + for bno in range(len(dt_boxes)): + tmp_box = copy.deepcopy(dt_boxes[bno]) + if self.args.det_box_type == "quad": + img_crop = get_rotate_crop_image(ori_im, tmp_box) + else: + img_crop = get_minarea_rect_crop(ori_im, tmp_box) + img_crop_list.append(img_crop) + if self.use_angle_cls and cls: + img_crop_list, angle_list, elapse = self.text_classifier( + img_crop_list) + time_dict['cls'] = elapse + logger.debug("cls num : {}, elapsed : {}".format( + len(img_crop_list), elapse)) + + rec_res, elapse = self.text_recognizer(img_crop_list) + time_dict['rec'] = elapse + logger.debug("rec_res num : {}, elapsed : {}".format( + len(rec_res), elapse)) + if self.args.save_crop_res: + self.draw_crop_rec_res(self.args.crop_res_save_dir, img_crop_list, + rec_res) + filter_boxes, filter_rec_res = [], [] + for box, rec_result in zip(dt_boxes, rec_res): + text, score = rec_result + if score >= self.drop_score: + filter_boxes.append(box) + filter_rec_res.append(rec_result) + end = time.time() + time_dict['all'] = end - start + return filter_boxes, filter_rec_res, time_dict \ No newline at end of file diff --git a/magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_291_mod.py 
class ModifiedPaddleOCR(PaddleOCR):
    """PaddleOCR subclass whose detected text boxes can be adjusted by ``mfd_res``.

    Both ``ocr`` and ``__call__`` accept an extra ``mfd_res`` argument. When it
    is truthy, the sorted detection boxes are passed through
    ``update_det_boxes(dt_boxes, mfd_res)`` before cropping and recognition
    (the accompanying log message describes this as splitting text boxes by
    formula).
    """

    def ocr(
        self,
        img,
        det=True,
        rec=True,
        cls=True,
        bin=False,
        inv=False,
        alpha_color=(255, 255, 255),
        slice=None,
        mfd_res=None,
    ):
        """
        OCR with PaddleOCR

        Args:
            img: Image for OCR. It can be an ndarray, img_path, or a list of ndarrays.
            det: Use text detection or not. If False, only text recognition will be executed. Default is True.
            rec: Use text recognition or not. If False, only text detection will be executed. Default is True.
            cls: Use angle classifier or not. Default is True. If True, the text with a rotation of 180 degrees can be recognized. If no text is rotated by 180 degrees, use cls=False to get better performance.
            bin: Binarize image to black and white. Default is False.
            inv: Invert image colors. Default is False.
            alpha_color: Set RGB color Tuple for transparent parts replacement. Default is pure white.
            slice: Use sliding window inference for large images. Both det and rec must be True. Requires int values for slice["horizontal_stride"], slice["vertical_stride"], slice["merge_x_thres"], slice["merge_y_thres"] (See doc/doc_en/slice_en.md). Default is None, which (like an empty dict) disables slicing.
            mfd_res: Optional value forwarded to ``__call__`` and from there to ``update_det_boxes``. Only used when both det and rec are True. Default is None.

        Returns:
            If both det and rec are True, returns a list of OCR results for each image. Each OCR result is a list of bounding boxes and recognized text for each detected text region (or None when nothing was found).
            If det is True and rec is False, returns a list of detected bounding boxes for each image.
            If det is False and rec is True, returns a list of recognized text for each image.
            If both det and rec are False, returns a list of angle classification results for each image.

        Raises:
            AssertionError: If the input image is not of type ndarray, list, str, or bytes.
            SystemExit: If det is True and the input is a list of images.

        Note:
            - If the angle classifier is not initialized (use_angle_cls=False), it will not be used during the forward process.
            - For PDF files, if the input is a list of images and the page_num is specified, only the first page_num images will be processed.
            - The preprocess_image function is used to preprocess the input image by applying alpha color replacement, inversion, and binarization if specified.
        """
        assert isinstance(img, (np.ndarray, list, str, bytes))
        if isinstance(img, list) and det == True:
            logger.error("When input a list of images, det must be false")
            # Raise SystemExit directly instead of calling the site-provided
            # ``exit`` helper; the exit status (0) is kept unchanged.
            raise SystemExit(0)
        if cls == True and self.use_angle_cls == False:
            logger.warning(
                "Since the angle classifier is not initialized, it will not be used during the forward process"
            )

        img, _flag_gif, flag_pdf = check_img(img, alpha_color)
        # for infer pdf file
        if isinstance(img, list) and flag_pdf:
            if self.page_num > len(img) or self.page_num == 0:
                imgs = img
            else:
                imgs = img[: self.page_num]
        else:
            imgs = [img]

        def preprocess_image(_image):
            _image = alpha_to_color(_image, alpha_color)
            if inv:
                _image = cv2.bitwise_not(_image)
            if bin:
                _image = binarize_img(_image)
            return _image

        if det and rec:
            # Full pipeline: detection (+ optional box update) + recognition.
            ocr_res = []
            for page in imgs:
                page = preprocess_image(page)
                dt_boxes, rec_res, _ = self.__call__(page, cls, slice, mfd_res=mfd_res)
                if not dt_boxes and not rec_res:
                    ocr_res.append(None)
                    continue
                ocr_res.append(
                    [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
                )
            return ocr_res

        if det and not rec:
            # Detection only.
            ocr_res = []
            for page in imgs:
                page = preprocess_image(page)
                dt_boxes, elapse = self.text_detector(page)
                if dt_boxes.size == 0:
                    ocr_res.append(None)
                    continue
                ocr_res.append([box.tolist() for box in dt_boxes])
            return ocr_res

        # det is False: classification and/or recognition on the given crops.
        ocr_res = []
        cls_res = []
        for page in imgs:
            if not isinstance(page, list):
                page = preprocess_image(page)
                page = [page]
            if self.use_angle_cls and cls:
                page, cls_res_tmp, elapse = self.text_classifier(page)
                if not rec:
                    cls_res.append(cls_res_tmp)
            rec_res, elapse = self.text_recognizer(page)
            ocr_res.append(rec_res)
        if not rec:
            return cls_res
        return ocr_res

    def __call__(self, img, cls=True, slice=None, mfd_res=None):
        """Detect, optionally classify, and recognize text on a single image.

        Args:
            img: Image to process; ``None`` yields an empty result.
            cls: Run the angle classifier on the crops when it is enabled.
            slice: Optional dict with "horizontal_stride", "vertical_stride",
                "merge_x_thres" and "merge_y_thres" enabling sliced detection.
                ``None`` or an empty dict runs detection on the whole image.
            mfd_res: When truthy, passed with the sorted boxes to
                ``update_det_boxes`` before cropping.

        Returns:
            ``(boxes, rec_results, time_dict)`` where only results whose score
            is at least ``self.drop_score`` are kept, or
            ``(None, None, time_dict)`` when there is no image or no box.
        """
        time_dict = {"det": 0, "rec": 0, "cls": 0, "all": 0}

        if img is None:
            logger.debug("no valid image provided")
            return None, None, time_dict

        start = time.time()
        ori_im = img.copy()
        if slice:
            slice_gen = slice_generator(
                img,
                horizontal_stride=slice["horizontal_stride"],
                vertical_stride=slice["vertical_stride"],
            )
            elapsed = []
            dt_slice_boxes = []
            for slice_crop, v_start, h_start in slice_gen:
                dt_boxes, elapse = self.text_detector(slice_crop, use_slice=True)
                if dt_boxes.size:
                    # Shift slice-local coordinates back into the full image.
                    dt_boxes[:, :, 0] += h_start
                    dt_boxes[:, :, 1] += v_start
                    dt_slice_boxes.append(dt_boxes)
                    elapsed.append(elapse)
            elapse = sum(elapsed)
            if dt_slice_boxes:
                dt_boxes = merge_fragmented(
                    boxes=np.concatenate(dt_slice_boxes),
                    x_threshold=slice["merge_x_thres"],
                    y_threshold=slice["merge_y_thres"],
                )
            else:
                # np.concatenate raises ValueError on an empty sequence, so a
                # page without any detected box is routed to the "no boxes"
                # branch below instead of crashing.
                dt_boxes = None
        else:
            dt_boxes, elapse = self.text_detector(img)

        time_dict["det"] = elapse

        if dt_boxes is None:
            logger.debug("no dt_boxes found, elapsed : {}".format(elapse))
            end = time.time()
            time_dict["all"] = end - start
            return None, None, time_dict

        logger.debug(
            "dt_boxes num : {}, elapsed : {}".format(len(dt_boxes), elapse)
        )

        dt_boxes = sorted_boxes(dt_boxes)

        if mfd_res:
            bef = time.time()
            dt_boxes = update_det_boxes(dt_boxes, mfd_res)
            aft = time.time()
            logger.debug("split text box by formula, new dt_boxes num : {}, elapsed : {}".format(
                len(dt_boxes), aft - bef))

        if self.args.det_box_type == "quad":
            crop_fn = get_rotate_crop_image
        else:
            crop_fn = get_minarea_rect_crop
        # Each box is deep-copied before cropping, as in the original code, so
        # the crop helper never receives the box object that is returned.
        img_crop_list = [crop_fn(ori_im, copy.deepcopy(box)) for box in dt_boxes]

        if self.use_angle_cls and cls:
            img_crop_list, _angle_list, elapse = self.text_classifier(img_crop_list)
            time_dict["cls"] = elapse
            logger.debug(
                "cls num : {}, elapsed : {}".format(len(img_crop_list), elapse)
            )
        if len(img_crop_list) > 1000:
            logger.debug(
                f"rec crops num: {len(img_crop_list)}, time and memory cost may be large."
            )

        rec_res, elapse = self.text_recognizer(img_crop_list)
        time_dict["rec"] = elapse
        logger.debug("rec_res num : {}, elapsed : {}".format(len(rec_res), elapse))
        if self.args.save_crop_res:
            self.draw_crop_rec_res(self.args.crop_res_save_dir, img_crop_list, rec_res)

        # Keep only the results whose recognition score reaches drop_score.
        filter_boxes, filter_rec_res = [], []
        for box, rec_result in zip(dt_boxes, rec_res):
            if rec_result[1] >= self.drop_score:
                filter_boxes.append(box)
                filter_rec_res.append(rec_result)
        end = time.time()
        time_dict["all"] = end - start
        return filter_boxes, filter_rec_res, time_dict
class RapidTableModel(object):
    """Table recognizer that feeds RapidOCR results into RapidTable."""

    def __init__(self, use_cuda=True):
        """Create the table model and its OCR engine.

        Args:
            use_cuda: Value passed as RapidOCR's ``det_use_cuda``,
                ``cls_use_cuda`` and ``rec_use_cuda`` flags. Defaults to
                True, which matches the previously hard-coded behaviour.
        """
        self.table_model = RapidTable()
        self.ocr_engine = RapidOCR(
            det_use_cuda=use_cuda, cls_use_cuda=use_cuda, rec_use_cuda=use_cuda
        )

    def predict(self, image):
        """Run OCR on ``image`` and convert the table it contains.

        Args:
            image: Anything ``np.asarray`` accepts (presumably a PIL image or
                an ndarray; verify against callers).

        Returns:
            ``(html_code, table_cell_bboxes, elapse)`` as produced by the
            table model, or ``(None, None, None)`` when the OCR engine found
            no text to hand to the table model.
        """
        # Convert once and reuse the array for both the OCR and table passes.
        image_array = np.asarray(image)
        ocr_result, _ = self.ocr_engine(image_array)
        if ocr_result is None or len(ocr_result) == 0:
            # Without OCR results there is nothing to place into table cells.
            return None, None, None
        html_code, table_cell_bboxes, elapse = self.table_model(image_array, ocr_result)
        return html_code, table_cell_bboxes, elapse
def minify_html(html):
    """Return ``html`` with insignificant whitespace removed.

    Three substitutions are applied in order: every whitespace run becomes a
    single space, whitespace around ``>`` is dropped, and whitespace around
    ``<`` is dropped. The result is stripped of leading/trailing whitespace.
    """
    substitutions = (
        # collapse runs of whitespace
        (r'\s+', ' '),
        # drop whitespace around the end of a tag
        (r'\s*>\s*', '>'),
        # drop whitespace around the start of a tag
        (r'\s*<\s*', '<'),
    )
    minified = html
    for pattern, replacement in substitutions:
        minified = re.sub(pattern, replacement, minified)
    return minified.strip()
ppTableModel(object): +class TableMasterPaddleModel(object): """ This class is responsible for converting image of table into HTML format using a pre-trained model. diff --git a/magic_pdf/pdf_parse_union_core_v2.py b/magic_pdf/pdf_parse_union_core_v2.py index 0cd1ed04..060bdb7c 100644 --- a/magic_pdf/pdf_parse_union_core_v2.py +++ b/magic_pdf/pdf_parse_union_core_v2.py @@ -164,8 +164,8 @@ class ModelSingleton: def do_predict(boxes: List[List[int]], model) -> List[int]: - from magic_pdf.model.v3.helpers import (boxes2inputs, parse_logits, - prepare_inputs) + from magic_pdf.model.sub_modules.reading_oreder.layoutreader.helpers import (boxes2inputs, parse_logits, + prepare_inputs) inputs = boxes2inputs(boxes) inputs = prepare_inputs(inputs, model) @@ -206,7 +206,7 @@ def cal_block_index(fix_blocks, sorted_bboxes): del block['real_lines'] import numpy as np - from magic_pdf.model.v3.xycut import recursive_xy_cut + from magic_pdf.model.sub_modules.reading_oreder.layoutreader.xycut import recursive_xy_cut random_boxes = np.array(block_bboxes) np.random.shuffle(random_boxes) diff --git a/setup.py b/setup.py index 80750d07..fc1f9b39 100644 --- a/setup.py +++ b/setup.py @@ -49,6 +49,7 @@ if __name__ == '__main__': "doclayout_yolo==0.0.2", # doclayout_yolo "rapidocr-paddle", # rapidocr-paddle "rapid_table", # rapid_table + "PyYAML", # yaml "detectron2" ], }, diff --git a/tests/test_table/test_tablemaster.py b/tests/test_table/test_tablemaster.py index f314feb1..f494355c 100644 --- a/tests/test_table/test_tablemaster.py +++ b/tests/test_table/test_tablemaster.py @@ -2,7 +2,7 @@ import unittest from PIL import Image from lxml import etree -from magic_pdf.model.ppTableModel import ppTableModel +from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import TableMasterPaddleModel class TestppTableModel(unittest.TestCase): @@ -11,7 +11,7 @@ class TestppTableModel(unittest.TestCase): # 修改table模型路径 config = {"device": "cuda", "model_dir": 
"/home/quyuan/.cache/modelscope/hub/opendatalab/PDF-Extract-Kit/models/TabRec/TableMaster"} - table_model = ppTableModel(config) + table_model = TableMasterPaddleModel(config) res = table_model.img2html(img) # 验证生成的 HTML 是否符合预期 parser = etree.HTMLParser()