Mirror of https://github.com/opendatalab/MinerU.git
Merge pull request #964 from myhloli/dev
refactor(model): rename and restructure model modules
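For orientation, a minimal sketch (not part of the diff) of how the restructured sub_modules are obtained after this refactor, based on the AtomModelSingleton and atom_model_init code shown below; the weight path is an illustrative assumption.

from magic_pdf.libs.Constants import MODEL_NAME
from magic_pdf.model.model_list import AtomicModel
from magic_pdf.model.sub_modules.model_init import AtomModelSingleton

atom_model_manager = AtomModelSingleton()

# Models are cached by (atom_model_name, layout_model_name, lang), so repeated calls reuse one instance.
layout_model = atom_model_manager.get_atom_model(
    atom_model_name=AtomicModel.Layout,
    layout_model_name=MODEL_NAME.DocLayout_YOLO,
    doclayout_yolo_weights="models/doclayout_yolo.pt",  # hypothetical weight path
    device="cuda",
)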
@@ -1,203 +1,28 @@
import numpy as np
import torch
from loguru import logger
import os
import time
from magic_pdf.libs.Constants import *
from magic_pdf.libs.clean_memory import clean_memory
from magic_pdf.model.model_list import AtomicModel
import cv2
import yaml
from PIL import Image

os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # disable albumentations update check
os.environ['YOLO_VERBOSE'] = 'False'  # disable yolo logger

try:
|
||||
import cv2
|
||||
import yaml
|
||||
import argparse
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchtext
|
||||
|
||||
if torchtext.__version__ >= "0.18.0":
|
||||
torchtext.disable_torchtext_deprecation_warning()
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from ultralytics import YOLO
|
||||
from unimernet.common.config import Config
|
||||
import unimernet.tasks as tasks
|
||||
from unimernet.processors import load_processor
|
||||
from doclayout_yolo import YOLOv10
|
||||
from rapid_table import RapidTable
|
||||
from rapidocr_paddle import RapidOCR
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
except ImportError as e:
|
||||
logger.exception(e)
|
||||
logger.error(
|
||||
'Required dependency not installed, please install by \n'
|
||||
'"pip install magic-pdf[full] --extra-index-url https://myhloli.github.io/wheels/"')
|
||||
exit(1)
|
||||
|
||||
from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
|
||||
from magic_pdf.model.pek_sub_modules.post_process import latex_rm_whitespace
|
||||
from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR
|
||||
from magic_pdf.model.pek_sub_modules.structeqtable.StructTableModel import StructTableModel
|
||||
from magic_pdf.model.ppTableModel import ppTableModel
|
||||
|
||||
|
||||
def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
|
||||
ocr_engine = None
|
||||
if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
|
||||
table_model = StructTableModel(model_path, max_time=max_time)
|
||||
elif table_model_type == MODEL_NAME.TABLE_MASTER:
|
||||
config = {
|
||||
"model_dir": model_path,
|
||||
"device": _device_
|
||||
}
|
||||
table_model = ppTableModel(config)
|
||||
elif table_model_type == MODEL_NAME.RAPID_TABLE:
|
||||
table_model = RapidTable()
|
||||
ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
|
||||
else:
|
||||
logger.error("table model type not allow")
|
||||
exit(1)
|
||||
|
||||
if ocr_engine:
|
||||
return [table_model, ocr_engine]
|
||||
else:
|
||||
return table_model
|
||||
|
||||
|
||||
def mfd_model_init(weight):
|
||||
mfd_model = YOLO(weight)
|
||||
return mfd_model
|
||||
|
||||
|
||||
def mfr_model_init(weight_dir, cfg_path, _device_='cpu'):
|
||||
args = argparse.Namespace(cfg_path=cfg_path, options=None)
|
||||
cfg = Config(args)
|
||||
cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.pth")
|
||||
cfg.config.model.model_config.model_name = weight_dir
|
||||
cfg.config.model.tokenizer_config.path = weight_dir
|
||||
task = tasks.setup_task(cfg)
|
||||
model = task.build_model(cfg)
|
||||
model.to(_device_)
|
||||
model.eval()
|
||||
vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
|
||||
mfr_transform = transforms.Compose([vis_processor, ])
|
||||
return [model, mfr_transform]
|
||||
|
||||
|
||||
def layout_model_init(weight, config_file, device):
|
||||
model = Layoutlmv3_Predictor(weight, config_file, device)
|
||||
return model
|
||||
|
||||
|
||||
def doclayout_yolo_model_init(weight):
|
||||
model = YOLOv10(weight)
|
||||
return model
|
||||
|
||||
|
||||
def ocr_model_init(show_log: bool = False, det_db_box_thresh=0.3, lang=None, use_dilation=True, det_db_unclip_ratio=1.8):
|
||||
if lang is not None:
|
||||
model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh, lang=lang, use_dilation=use_dilation, det_db_unclip_ratio=det_db_unclip_ratio)
|
||||
else:
|
||||
model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh, use_dilation=use_dilation, det_db_unclip_ratio=det_db_unclip_ratio)
|
||||
return model
|
||||
|
||||
|
||||
class MathDataset(Dataset):
|
||||
def __init__(self, image_paths, transform=None):
|
||||
self.image_paths = image_paths
|
||||
self.transform = transform
|
||||
|
||||
def __len__(self):
|
||||
return len(self.image_paths)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
# if not pil image, then convert to pil image
|
||||
if isinstance(self.image_paths[idx], str):
|
||||
raw_image = Image.open(self.image_paths[idx])
|
||||
else:
|
||||
raw_image = self.image_paths[idx]
|
||||
if self.transform:
|
||||
image = self.transform(raw_image)
|
||||
return image
|
||||
|
||||
|
||||
class AtomModelSingleton:
|
||||
_instance = None
|
||||
_models = {}
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def get_atom_model(self, atom_model_name: str, **kwargs):
|
||||
lang = kwargs.get("lang", None)
|
||||
layout_model_name = kwargs.get("layout_model_name", None)
|
||||
key = (atom_model_name, layout_model_name, lang)
|
||||
if key not in self._models:
|
||||
self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
|
||||
return self._models[key]
|
||||
|
||||
|
||||
def atom_model_init(model_name: str, **kwargs):
|
||||
|
||||
if model_name == AtomicModel.Layout:
|
||||
if kwargs.get("layout_model_name") == MODEL_NAME.LAYOUTLMv3:
|
||||
atom_model = layout_model_init(
|
||||
kwargs.get("layout_weights"),
|
||||
kwargs.get("layout_config_file"),
|
||||
kwargs.get("device")
|
||||
)
|
||||
elif kwargs.get("layout_model_name") == MODEL_NAME.DocLayout_YOLO:
|
||||
atom_model = doclayout_yolo_model_init(
|
||||
kwargs.get("doclayout_yolo_weights"),
|
||||
)
|
||||
elif model_name == AtomicModel.MFD:
|
||||
atom_model = mfd_model_init(
|
||||
kwargs.get("mfd_weights")
|
||||
)
|
||||
elif model_name == AtomicModel.MFR:
|
||||
atom_model = mfr_model_init(
|
||||
kwargs.get("mfr_weight_dir"),
|
||||
kwargs.get("mfr_cfg_path"),
|
||||
kwargs.get("device")
|
||||
)
|
||||
elif model_name == AtomicModel.OCR:
|
||||
atom_model = ocr_model_init(
|
||||
kwargs.get("ocr_show_log"),
|
||||
kwargs.get("det_db_box_thresh"),
|
||||
kwargs.get("lang")
|
||||
)
|
||||
elif model_name == AtomicModel.Table:
|
||||
atom_model = table_model_init(
|
||||
kwargs.get("table_model_name"),
|
||||
kwargs.get("table_model_path"),
|
||||
kwargs.get("table_max_time"),
|
||||
kwargs.get("device")
|
||||
)
|
||||
else:
|
||||
logger.error("model name not allow")
|
||||
exit(1)
|
||||
|
||||
return atom_model
|
||||
|
||||
|
||||
# Unified crop img logic
|
||||
def crop_img(input_res, input_pil_img, crop_paste_x=0, crop_paste_y=0):
|
||||
crop_xmin, crop_ymin = int(input_res['poly'][0]), int(input_res['poly'][1])
|
||||
crop_xmax, crop_ymax = int(input_res['poly'][4]), int(input_res['poly'][5])
|
||||
# Create a white background with an additional width and height of 50
|
||||
crop_new_width = crop_xmax - crop_xmin + crop_paste_x * 2
|
||||
crop_new_height = crop_ymax - crop_ymin + crop_paste_y * 2
|
||||
return_image = Image.new('RGB', (crop_new_width, crop_new_height), 'white')
|
||||
|
||||
# Crop image
|
||||
crop_box = (crop_xmin, crop_ymin, crop_xmax, crop_ymax)
|
||||
cropped_img = input_pil_img.crop(crop_box)
|
||||
return_image.paste(cropped_img, (crop_paste_x, crop_paste_y))
|
||||
return_list = [crop_paste_x, crop_paste_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width, crop_new_height]
|
||||
return return_image, return_list
|
||||
from magic_pdf.libs.Constants import *
from magic_pdf.model.model_list import AtomicModel
from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
from magic_pdf.model.sub_modules.model_utils import get_res_list_from_layout_res, crop_img, clean_vram
from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import get_adjusted_mfdetrec_res, get_ocr_result_list
|
||||
|
||||
|
||||
class CustomPEKModel:
|
||||
@@ -243,7 +68,8 @@ class CustomPEKModel:
|
||||
logger.info(
|
||||
"DocAnalysis init, this may take some times, layout_model: {}, apply_formula: {}, apply_ocr: {}, "
|
||||
"apply_table: {}, table_model: {}, lang: {}".format(
|
||||
self.layout_model_name, self.apply_formula, self.apply_ocr, self.apply_table, self.table_model_name, self.lang
|
||||
self.layout_model_name, self.apply_formula, self.apply_ocr, self.apply_table, self.table_model_name,
|
||||
self.lang
|
||||
)
|
||||
)
|
||||
# 初始化解析方案
|
||||
@@ -256,17 +82,17 @@ class CustomPEKModel:
|
||||
|
||||
# 初始化公式识别
|
||||
if self.apply_formula:
|
||||
|
||||
# 初始化公式检测模型
|
||||
self.mfd_model = atom_model_manager.get_atom_model(
|
||||
atom_model_name=AtomicModel.MFD,
|
||||
mfd_weights=str(os.path.join(models_dir, self.configs["weights"][self.mfd_model_name]))
|
||||
mfd_weights=str(os.path.join(models_dir, self.configs["weights"][self.mfd_model_name])),
|
||||
device=self.device
|
||||
)
|
||||
|
||||
# 初始化公式解析模型
|
||||
mfr_weight_dir = str(os.path.join(models_dir, self.configs["weights"][self.mfr_model_name]))
|
||||
mfr_cfg_path = str(os.path.join(model_config_dir, "UniMERNet", "demo.yaml"))
|
||||
self.mfr_model, self.mfr_transform = atom_model_manager.get_atom_model(
|
||||
self.mfr_model = atom_model_manager.get_atom_model(
|
||||
atom_model_name=AtomicModel.MFR,
|
||||
mfr_weight_dir=mfr_weight_dir,
|
||||
mfr_cfg_path=mfr_cfg_path,
|
||||
@@ -286,7 +112,8 @@ class CustomPEKModel:
|
||||
self.layout_model = atom_model_manager.get_atom_model(
|
||||
atom_model_name=AtomicModel.Layout,
|
||||
layout_model_name=MODEL_NAME.DocLayout_YOLO,
|
||||
doclayout_yolo_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name]))
|
||||
doclayout_yolo_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name])),
|
||||
device=self.device
|
||||
)
|
||||
# 初始化ocr
|
||||
if self.apply_ocr:
|
||||
@@ -299,22 +126,13 @@ class CustomPEKModel:
|
||||
# init table model
|
||||
if self.apply_table:
|
||||
table_model_dir = self.configs["weights"][self.table_model_name]
|
||||
if self.table_model_name in [MODEL_NAME.STRUCT_EQTABLE, MODEL_NAME.TABLE_MASTER]:
|
||||
self.table_model = atom_model_manager.get_atom_model(
|
||||
atom_model_name=AtomicModel.Table,
|
||||
table_model_name=self.table_model_name,
|
||||
table_model_path=str(os.path.join(models_dir, table_model_dir)),
|
||||
table_max_time=self.table_max_time,
|
||||
device=self.device
|
||||
)
|
||||
elif self.table_model_name in [MODEL_NAME.RAPID_TABLE]:
|
||||
self.table_model, self.ocr_engine =atom_model_manager.get_atom_model(
|
||||
atom_model_name=AtomicModel.Table,
|
||||
table_model_name=self.table_model_name,
|
||||
table_model_path=str(os.path.join(models_dir, table_model_dir)),
|
||||
table_max_time=self.table_max_time,
|
||||
device=self.device
|
||||
)
|
||||
self.table_model = atom_model_manager.get_atom_model(
|
||||
atom_model_name=AtomicModel.Table,
|
||||
table_model_name=self.table_model_name,
|
||||
table_model_path=str(os.path.join(models_dir, table_model_dir)),
|
||||
table_max_time=self.table_max_time,
|
||||
device=self.device
|
||||
)
|
||||
|
||||
logger.info('DocAnalysis init done!')
|
||||
|
||||
@@ -322,26 +140,15 @@ class CustomPEKModel:
|
||||
|
||||
page_start = time.time()
|
||||
|
||||
latex_filling_list = []
|
||||
mf_image_list = []
|
||||
|
||||
# layout检测
|
||||
layout_start = time.time()
|
||||
layout_res = []
|
||||
if self.layout_model_name == MODEL_NAME.LAYOUTLMv3:
|
||||
# layoutlmv3
|
||||
layout_res = self.layout_model(image, ignore_catids=[])
|
||||
elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
|
||||
# doclayout_yolo
|
||||
layout_res = []
|
||||
doclayout_yolo_res = self.layout_model.predict(image, imgsz=1024, conf=0.25, iou=0.45, verbose=True, device=self.device)[0]
|
||||
for xyxy, conf, cla in zip(doclayout_yolo_res.boxes.xyxy.cpu(), doclayout_yolo_res.boxes.conf.cpu(), doclayout_yolo_res.boxes.cls.cpu()):
|
||||
xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
|
||||
new_item = {
|
||||
'category_id': int(cla.item()),
|
||||
'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
|
||||
'score': round(float(conf.item()), 3),
|
||||
}
|
||||
layout_res.append(new_item)
|
||||
layout_res = self.layout_model.predict(image)
|
||||
layout_cost = round(time.time() - layout_start, 2)
|
||||
logger.info(f"layout detection time: {layout_cost}")
|
||||
|
||||
@@ -350,58 +157,21 @@ class CustomPEKModel:
|
||||
if self.apply_formula:
|
||||
# 公式检测
|
||||
mfd_start = time.time()
|
||||
mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True, device=self.device)[0]
|
||||
mfd_res = self.mfd_model.predict(image)
|
||||
logger.info(f"mfd time: {round(time.time() - mfd_start, 2)}")
|
||||
for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
|
||||
xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
|
||||
new_item = {
|
||||
'category_id': 13 + int(cla.item()),
|
||||
'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
|
||||
'score': round(float(conf.item()), 2),
|
||||
'latex': '',
|
||||
}
|
||||
layout_res.append(new_item)
|
||||
latex_filling_list.append(new_item)
|
||||
bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
|
||||
mf_image_list.append(bbox_img)
|
||||
|
||||
# 公式识别
|
||||
mfr_start = time.time()
|
||||
dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
|
||||
dataloader = DataLoader(dataset, batch_size=64, num_workers=0)
|
||||
mfr_res = []
|
||||
for mf_img in dataloader:
|
||||
mf_img = mf_img.to(self.device)
|
||||
with torch.no_grad():
|
||||
output = self.mfr_model.generate({'image': mf_img})
|
||||
mfr_res.extend(output['pred_str'])
|
||||
for res, latex in zip(latex_filling_list, mfr_res):
|
||||
res['latex'] = latex_rm_whitespace(latex)
|
||||
formula_list = self.mfr_model.predict(mfd_res, image)
|
||||
layout_res.extend(formula_list)
|
||||
mfr_cost = round(time.time() - mfr_start, 2)
|
||||
logger.info(f"formula nums: {len(mf_image_list)}, mfr time: {mfr_cost}")
|
||||
logger.info(f"formula nums: {len(formula_list)}, mfr time: {mfr_cost}")
|
||||
|
||||
# Select regions for OCR / formula regions / table regions
|
||||
ocr_res_list = []
|
||||
table_res_list = []
|
||||
single_page_mfdetrec_res = []
|
||||
for res in layout_res:
|
||||
if int(res['category_id']) in [13, 14]:
|
||||
single_page_mfdetrec_res.append({
|
||||
"bbox": [int(res['poly'][0]), int(res['poly'][1]),
|
||||
int(res['poly'][4]), int(res['poly'][5])],
|
||||
})
|
||||
elif int(res['category_id']) in [0, 1, 2, 4, 6, 7]:
|
||||
ocr_res_list.append(res)
|
||||
elif int(res['category_id']) in [5]:
|
||||
table_res_list.append(res)
|
||||
# 清理显存
|
||||
clean_vram(self.device, vram_threshold=8)
|
||||
|
||||
if torch.cuda.is_available() and self.device != 'cpu':
|
||||
total_memory = torch.cuda.get_device_properties(self.device).total_memory / (1024 ** 3) # 将字节转换为 GB
|
||||
if total_memory <= 8:
|
||||
gc_start = time.time()
|
||||
clean_memory()
|
||||
gc_time = round(time.time() - gc_start, 2)
|
||||
logger.info(f"gc time: {gc_time}")
|
||||
# 从layout_res中获取ocr区域、表格区域、公式区域
|
||||
ocr_res_list, table_res_list, single_page_mfdetrec_res = get_res_list_from_layout_res(layout_res)
|
||||
|
||||
# ocr识别
|
||||
if self.apply_ocr:
|
||||
@@ -409,23 +179,7 @@ class CustomPEKModel:
|
||||
# Process each area that requires OCR processing
|
||||
for res in ocr_res_list:
|
||||
new_image, useful_list = crop_img(res, pil_img, crop_paste_x=50, crop_paste_y=50)
|
||||
paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
|
||||
# Adjust the coordinates of the formula area
|
||||
adjusted_mfdetrec_res = []
|
||||
for mf_res in single_page_mfdetrec_res:
|
||||
mf_xmin, mf_ymin, mf_xmax, mf_ymax = mf_res["bbox"]
|
||||
# Adjust the coordinates of the formula area to the coordinates relative to the cropping area
|
||||
x0 = mf_xmin - xmin + paste_x
|
||||
y0 = mf_ymin - ymin + paste_y
|
||||
x1 = mf_xmax - xmin + paste_x
|
||||
y1 = mf_ymax - ymin + paste_y
|
||||
# Filter formula blocks outside the graph
|
||||
if any([x1 < 0, y1 < 0]) or any([x0 > new_width, y0 > new_height]):
|
||||
continue
|
||||
else:
|
||||
adjusted_mfdetrec_res.append({
|
||||
"bbox": [x0, y0, x1, y1],
|
||||
})
|
||||
adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list)
|
||||
|
||||
# OCR recognition
|
||||
new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
|
||||
@@ -433,22 +187,8 @@ class CustomPEKModel:
|
||||
|
||||
# Integration results
|
||||
if ocr_res:
|
||||
for box_ocr_res in ocr_res:
|
||||
p1, p2, p3, p4 = box_ocr_res[0]
|
||||
text, score = box_ocr_res[1]
|
||||
|
||||
# Convert the coordinates back to the original coordinate system
|
||||
p1 = [p1[0] - paste_x + xmin, p1[1] - paste_y + ymin]
|
||||
p2 = [p2[0] - paste_x + xmin, p2[1] - paste_y + ymin]
|
||||
p3 = [p3[0] - paste_x + xmin, p3[1] - paste_y + ymin]
|
||||
p4 = [p4[0] - paste_x + xmin, p4[1] - paste_y + ymin]
|
||||
|
||||
layout_res.append({
|
||||
'category_id': 15,
|
||||
'poly': p1 + p2 + p3 + p4,
|
||||
'score': round(score, 2),
|
||||
'text': text,
|
||||
})
|
||||
ocr_result_list = get_ocr_result_list(ocr_res, useful_list)
|
||||
layout_res.extend(ocr_result_list)
|
||||
|
||||
ocr_cost = round(time.time() - ocr_start, 2)
|
||||
logger.info(f"ocr time: {ocr_cost}")
|
||||
@@ -459,8 +199,6 @@ class CustomPEKModel:
|
||||
for res in table_res_list:
|
||||
new_image, _ = crop_img(res, pil_img)
|
||||
single_table_start_time = time.time()
|
||||
# logger.info("------------------table recognition processing begins-----------------")
|
||||
latex_code = None
|
||||
html_code = None
|
||||
if self.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
|
||||
with torch.no_grad():
|
||||
@@ -470,33 +208,21 @@ class CustomPEKModel:
|
||||
elif self.table_model_name == MODEL_NAME.TABLE_MASTER:
|
||||
html_code = self.table_model.img2html(new_image)
|
||||
elif self.table_model_name == MODEL_NAME.RAPID_TABLE:
|
||||
ocr_result, _ = self.ocr_engine(np.asarray(new_image))
|
||||
html_code, table_cell_bboxes, elapse = self.table_model(np.asarray(new_image), ocr_result)
|
||||
|
||||
html_code, table_cell_bboxes, elapse = self.table_model.predict(new_image)
|
||||
run_time = time.time() - single_table_start_time
|
||||
# logger.info(f"------------table recognition processing ends within {run_time}s-----")
|
||||
if run_time > self.table_max_time:
|
||||
logger.warning(f"------------table recognition processing exceeds max time {self.table_max_time}s----------")
|
||||
logger.warning(f"table recognition processing exceeds max time {self.table_max_time}s")
|
||||
# 判断是否返回正常
|
||||
|
||||
if latex_code:
|
||||
expected_ending = latex_code.strip().endswith('end{tabular}') or latex_code.strip().endswith('end{table}')
|
||||
if expected_ending:
|
||||
res["latex"] = latex_code
|
||||
else:
|
||||
logger.warning(f"table recognition processing fails, not found expected LaTeX table end")
|
||||
elif html_code:
|
||||
if html_code:
|
||||
expected_ending = html_code.strip().endswith('</html>') or html_code.strip().endswith('</table>')
|
||||
if expected_ending:
|
||||
res["html"] = html_code
|
||||
else:
|
||||
logger.warning(f"table recognition processing fails, not found expected HTML table end")
|
||||
else:
|
||||
logger.warning(f"table recognition processing fails, not get latex or html return")
|
||||
logger.warning(f"table recognition processing fails, not get html return")
|
||||
logger.info(f"table time: {round(time.time() - table_start, 2)}")
|
||||
|
||||
logger.info(f"-----page total time: {round(time.time() - page_start, 2)}-----")
|
||||
|
||||
return layout_res
|
||||
|
||||
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
import re
|
||||
|
||||
def layout_rm_equation(layout_res):
|
||||
rm_idxs = []
|
||||
for idx, ele in enumerate(layout_res['layout_dets']):
|
||||
if ele['category_id'] == 10:
|
||||
rm_idxs.append(idx)
|
||||
|
||||
for idx in rm_idxs[::-1]:
|
||||
del layout_res['layout_dets'][idx]
|
||||
return layout_res
|
||||
|
||||
|
||||
def get_croped_image(image_pil, bbox):
|
||||
x_min, y_min, x_max, y_max = bbox
|
||||
croped_img = image_pil.crop((x_min, y_min, x_max, y_max))
|
||||
return croped_img
|
||||
|
||||
|
||||
def latex_rm_whitespace(s: str):
|
||||
"""Remove unnecessary whitespace from LaTeX code.
|
||||
"""
|
||||
text_reg = r'(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})'
|
||||
letter = '[a-zA-Z]'
|
||||
noletter = '[\W_^\d]'
|
||||
names = [x[0].replace(' ', '') for x in re.findall(text_reg, s)]
|
||||
s = re.sub(text_reg, lambda match: str(names.pop(0)), s)
|
||||
news = s
|
||||
while True:
|
||||
s = news
|
||||
news = re.sub(r'(?!\\ )(%s)\s+?(%s)' % (noletter, noletter), r'\1\2', s)
|
||||
news = re.sub(r'(?!\\ )(%s)\s+?(%s)' % (noletter, letter), r'\1\2', news)
|
||||
news = re.sub(r'(%s)\s+?(%s)' % (letter, noletter), r'\1\2', news)
|
||||
if news == s:
|
||||
break
|
||||
return s
|
||||
@@ -1,388 +0,0 @@
|
||||
import time
|
||||
import copy
|
||||
import base64
|
||||
import cv2
|
||||
import numpy as np
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
|
||||
from paddleocr import PaddleOCR
|
||||
from paddleocr.ppocr.utils.logging import get_logger
|
||||
from paddleocr.ppocr.utils.utility import check_and_read, alpha_to_color, binarize_img
|
||||
from paddleocr.tools.infer.utility import draw_ocr_box_txt, get_rotate_crop_image, get_minarea_rect_crop
|
||||
|
||||
from magic_pdf.libs.boxbase import __is_overlaps_y_exceeds_threshold
|
||||
from magic_pdf.pre_proc.ocr_dict_merge import merge_spans_to_line
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
def img_decode(content: bytes):
|
||||
np_arr = np.frombuffer(content, dtype=np.uint8)
|
||||
return cv2.imdecode(np_arr, cv2.IMREAD_UNCHANGED)
|
||||
|
||||
|
||||
def check_img(img):
|
||||
if isinstance(img, bytes):
|
||||
img = img_decode(img)
|
||||
if isinstance(img, str):
|
||||
image_file = img
|
||||
img, flag_gif, flag_pdf = check_and_read(image_file)
|
||||
if not flag_gif and not flag_pdf:
|
||||
with open(image_file, 'rb') as f:
|
||||
img_str = f.read()
|
||||
img = img_decode(img_str)
|
||||
if img is None:
|
||||
try:
|
||||
buf = BytesIO()
|
||||
image = BytesIO(img_str)
|
||||
im = Image.open(image)
|
||||
rgb = im.convert('RGB')
|
||||
rgb.save(buf, 'jpeg')
|
||||
buf.seek(0)
|
||||
image_bytes = buf.read()
|
||||
data_base64 = str(base64.b64encode(image_bytes),
|
||||
encoding="utf-8")
|
||||
image_decode = base64.b64decode(data_base64)
|
||||
img_array = np.frombuffer(image_decode, np.uint8)
|
||||
img = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
|
||||
except:
|
||||
logger.error("error in loading image:{}".format(image_file))
|
||||
return None
|
||||
if img is None:
|
||||
logger.error("error in loading image:{}".format(image_file))
|
||||
return None
|
||||
if isinstance(img, np.ndarray) and len(img.shape) == 2:
|
||||
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
||||
|
||||
return img
|
||||
|
||||
|
||||
def sorted_boxes(dt_boxes):
|
||||
"""
|
||||
Sort text boxes in order from top to bottom, left to right
|
||||
args:
|
||||
dt_boxes(array):detected text boxes with shape [4, 2]
|
||||
return:
|
||||
sorted boxes(array) with shape [4, 2]
|
||||
"""
|
||||
num_boxes = dt_boxes.shape[0]
|
||||
sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
|
||||
_boxes = list(sorted_boxes)
|
||||
|
||||
for i in range(num_boxes - 1):
|
||||
for j in range(i, -1, -1):
|
||||
if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 and \
|
||||
(_boxes[j + 1][0][0] < _boxes[j][0][0]):
|
||||
tmp = _boxes[j]
|
||||
_boxes[j] = _boxes[j + 1]
|
||||
_boxes[j + 1] = tmp
|
||||
else:
|
||||
break
|
||||
return _boxes
|
||||
|
||||
|
||||
def bbox_to_points(bbox):
|
||||
""" 将bbox格式转换为四个顶点的数组 """
|
||||
x0, y0, x1, y1 = bbox
|
||||
return np.array([[x0, y0], [x1, y0], [x1, y1], [x0, y1]]).astype('float32')
|
||||
|
||||
|
||||
def points_to_bbox(points):
|
||||
""" 将四个顶点的数组转换为bbox格式 """
|
||||
x0, y0 = points[0]
|
||||
x1, _ = points[1]
|
||||
_, y1 = points[2]
|
||||
return [x0, y0, x1, y1]
|
||||
|
||||
|
||||
def merge_intervals(intervals):
|
||||
# Sort the intervals based on the start value
|
||||
intervals.sort(key=lambda x: x[0])
|
||||
|
||||
merged = []
|
||||
for interval in intervals:
|
||||
# If the list of merged intervals is empty or if the current
|
||||
# interval does not overlap with the previous, simply append it.
|
||||
if not merged or merged[-1][1] < interval[0]:
|
||||
merged.append(interval)
|
||||
else:
|
||||
# Otherwise, there is overlap, so we merge the current and previous intervals.
|
||||
merged[-1][1] = max(merged[-1][1], interval[1])
|
||||
|
||||
return merged
|
||||
|
||||
|
||||
def remove_intervals(original, masks):
|
||||
# Merge all mask intervals
|
||||
merged_masks = merge_intervals(masks)
|
||||
|
||||
result = []
|
||||
original_start, original_end = original
|
||||
|
||||
for mask in merged_masks:
|
||||
mask_start, mask_end = mask
|
||||
|
||||
# If the mask starts after the original range, ignore it
|
||||
if mask_start > original_end:
|
||||
continue
|
||||
|
||||
# If the mask ends before the original range starts, ignore it
|
||||
if mask_end < original_start:
|
||||
continue
|
||||
|
||||
# Remove the masked part from the original range
|
||||
if original_start < mask_start:
|
||||
result.append([original_start, mask_start - 1])
|
||||
|
||||
original_start = max(mask_end + 1, original_start)
|
||||
|
||||
# Add the remaining part of the original range, if any
|
||||
if original_start <= original_end:
|
||||
result.append([original_start, original_end])
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def update_det_boxes(dt_boxes, mfd_res):
|
||||
new_dt_boxes = []
|
||||
for text_box in dt_boxes:
|
||||
text_bbox = points_to_bbox(text_box)
|
||||
masks_list = []
|
||||
for mf_box in mfd_res:
|
||||
mf_bbox = mf_box['bbox']
|
||||
if __is_overlaps_y_exceeds_threshold(text_bbox, mf_bbox):
|
||||
masks_list.append([mf_bbox[0], mf_bbox[2]])
|
||||
text_x_range = [text_bbox[0], text_bbox[2]]
|
||||
text_remove_mask_range = remove_intervals(text_x_range, masks_list)
|
||||
temp_dt_box = []
|
||||
for text_remove_mask in text_remove_mask_range:
|
||||
temp_dt_box.append(bbox_to_points([text_remove_mask[0], text_bbox[1], text_remove_mask[1], text_bbox[3]]))
|
||||
if len(temp_dt_box) > 0:
|
||||
new_dt_boxes.extend(temp_dt_box)
|
||||
return new_dt_boxes
|
||||
|
||||
|
||||
def merge_overlapping_spans(spans):
|
||||
"""
|
||||
Merges overlapping spans on the same line.
|
||||
|
||||
:param spans: A list of span coordinates [(x1, y1, x2, y2), ...]
|
||||
:return: A list of merged spans
|
||||
"""
|
||||
# Return an empty list if the input spans list is empty
|
||||
if not spans:
|
||||
return []
|
||||
|
||||
# Sort spans by their starting x-coordinate
|
||||
spans.sort(key=lambda x: x[0])
|
||||
|
||||
# Initialize the list of merged spans
|
||||
merged = []
|
||||
for span in spans:
|
||||
# Unpack span coordinates
|
||||
x1, y1, x2, y2 = span
|
||||
# If the merged list is empty or there's no horizontal overlap, add the span directly
|
||||
if not merged or merged[-1][2] < x1:
|
||||
merged.append(span)
|
||||
else:
|
||||
# If there is horizontal overlap, merge the current span with the previous one
|
||||
last_span = merged.pop()
|
||||
# Update the merged span's top-left corner to the smaller (x1, y1) and bottom-right to the larger (x2, y2)
|
||||
x1 = min(last_span[0], x1)
|
||||
y1 = min(last_span[1], y1)
|
||||
x2 = max(last_span[2], x2)
|
||||
y2 = max(last_span[3], y2)
|
||||
# Add the merged span back to the list
|
||||
merged.append((x1, y1, x2, y2))
|
||||
|
||||
# Return the list of merged spans
|
||||
return merged
|
||||
|
||||
|
||||
def merge_det_boxes(dt_boxes):
|
||||
"""
|
||||
Merge detection boxes.
|
||||
|
||||
This function takes a list of detected bounding boxes, each represented by four corner points.
|
||||
The goal is to merge these bounding boxes into larger text regions.
|
||||
|
||||
Parameters:
|
||||
dt_boxes (list): A list containing multiple text detection boxes, where each box is defined by four corner points.
|
||||
|
||||
Returns:
|
||||
list: A list containing the merged text regions, where each region is represented by four corner points.
|
||||
"""
|
||||
# Convert the detection boxes into a dictionary format with bounding boxes and type
|
||||
dt_boxes_dict_list = []
|
||||
for text_box in dt_boxes:
|
||||
text_bbox = points_to_bbox(text_box)
|
||||
text_box_dict = {
|
||||
'bbox': text_bbox,
|
||||
'type': 'text',
|
||||
}
|
||||
dt_boxes_dict_list.append(text_box_dict)
|
||||
|
||||
# Merge adjacent text regions into lines
|
||||
lines = merge_spans_to_line(dt_boxes_dict_list)
|
||||
|
||||
# Initialize a new list for storing the merged text regions
|
||||
new_dt_boxes = []
|
||||
for line in lines:
|
||||
line_bbox_list = []
|
||||
for span in line:
|
||||
line_bbox_list.append(span['bbox'])
|
||||
|
||||
# Merge overlapping text regions within the same line
|
||||
merged_spans = merge_overlapping_spans(line_bbox_list)
|
||||
|
||||
# Convert the merged text regions back to point format and add them to the new detection box list
|
||||
for span in merged_spans:
|
||||
new_dt_boxes.append(bbox_to_points(span))
|
||||
|
||||
return new_dt_boxes
|
||||
|
||||
|
||||
class ModifiedPaddleOCR(PaddleOCR):
|
||||
def ocr(self, img, det=True, rec=True, cls=True, bin=False, inv=False, mfd_res=None, alpha_color=(255, 255, 255)):
|
||||
"""
|
||||
OCR with PaddleOCR
|
||||
args:
|
||||
img: img for OCR, support ndarray, img_path and list or ndarray
|
||||
det: use text detection or not. If False, only rec will be exec. Default is True
|
||||
rec: use text recognition or not. If False, only det will be exec. Default is True
|
||||
cls: use angle classifier or not. Default is True. If True, the text with rotation of 180 degrees can be recognized. If no text is rotated by 180 degrees, use cls=False to get better performance. Text with rotation of 90 or 270 degrees can be recognized even if cls=False.
|
||||
bin: binarize image to black and white. Default is False.
|
||||
inv: invert image colors. Default is False.
|
||||
alpha_color: set RGB color Tuple for transparent parts replacement. Default is pure white.
|
||||
"""
|
||||
assert isinstance(img, (np.ndarray, list, str, bytes))
|
||||
if isinstance(img, list) and det == True:
|
||||
logger.error('When input a list of images, det must be false')
|
||||
exit(0)
|
||||
if cls == True and self.use_angle_cls == False:
|
||||
pass
|
||||
# logger.warning(
|
||||
# 'Since the angle classifier is not initialized, it will not be used during the forward process'
|
||||
# )
|
||||
|
||||
img = check_img(img)
|
||||
# for infer pdf file
|
||||
if isinstance(img, list):
|
||||
if self.page_num > len(img) or self.page_num == 0:
|
||||
self.page_num = len(img)
|
||||
imgs = img[:self.page_num]
|
||||
else:
|
||||
imgs = [img]
|
||||
|
||||
def preprocess_image(_image):
|
||||
_image = alpha_to_color(_image, alpha_color)
|
||||
if inv:
|
||||
_image = cv2.bitwise_not(_image)
|
||||
if bin:
|
||||
_image = binarize_img(_image)
|
||||
return _image
|
||||
|
||||
if det and rec:
|
||||
ocr_res = []
|
||||
for idx, img in enumerate(imgs):
|
||||
img = preprocess_image(img)
|
||||
dt_boxes, rec_res, _ = self.__call__(img, cls, mfd_res=mfd_res)
|
||||
if not dt_boxes and not rec_res:
|
||||
ocr_res.append(None)
|
||||
continue
|
||||
tmp_res = [[box.tolist(), res]
|
||||
for box, res in zip(dt_boxes, rec_res)]
|
||||
ocr_res.append(tmp_res)
|
||||
return ocr_res
|
||||
elif det and not rec:
|
||||
ocr_res = []
|
||||
for idx, img in enumerate(imgs):
|
||||
img = preprocess_image(img)
|
||||
dt_boxes, elapse = self.text_detector(img)
|
||||
if not dt_boxes:
|
||||
ocr_res.append(None)
|
||||
continue
|
||||
tmp_res = [box.tolist() for box in dt_boxes]
|
||||
ocr_res.append(tmp_res)
|
||||
return ocr_res
|
||||
else:
|
||||
ocr_res = []
|
||||
cls_res = []
|
||||
for idx, img in enumerate(imgs):
|
||||
if not isinstance(img, list):
|
||||
img = preprocess_image(img)
|
||||
img = [img]
|
||||
if self.use_angle_cls and cls:
|
||||
img, cls_res_tmp, elapse = self.text_classifier(img)
|
||||
if not rec:
|
||||
cls_res.append(cls_res_tmp)
|
||||
rec_res, elapse = self.text_recognizer(img)
|
||||
ocr_res.append(rec_res)
|
||||
if not rec:
|
||||
return cls_res
|
||||
return ocr_res
|
||||
|
||||
def __call__(self, img, cls=True, mfd_res=None):
|
||||
time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0}
|
||||
|
||||
if img is None:
|
||||
logger.debug("no valid image provided")
|
||||
return None, None, time_dict
|
||||
|
||||
start = time.time()
|
||||
ori_im = img.copy()
|
||||
dt_boxes, elapse = self.text_detector(img)
|
||||
time_dict['det'] = elapse
|
||||
|
||||
if dt_boxes is None:
|
||||
logger.debug("no dt_boxes found, elapsed : {}".format(elapse))
|
||||
end = time.time()
|
||||
time_dict['all'] = end - start
|
||||
return None, None, time_dict
|
||||
else:
|
||||
logger.debug("dt_boxes num : {}, elapsed : {}".format(
|
||||
len(dt_boxes), elapse))
|
||||
img_crop_list = []
|
||||
|
||||
dt_boxes = sorted_boxes(dt_boxes)
|
||||
|
||||
dt_boxes = merge_det_boxes(dt_boxes)
|
||||
|
||||
if mfd_res:
|
||||
bef = time.time()
|
||||
dt_boxes = update_det_boxes(dt_boxes, mfd_res)
|
||||
aft = time.time()
|
||||
logger.debug("split text box by formula, new dt_boxes num : {}, elapsed : {}".format(
|
||||
len(dt_boxes), aft - bef))
|
||||
|
||||
for bno in range(len(dt_boxes)):
|
||||
tmp_box = copy.deepcopy(dt_boxes[bno])
|
||||
if self.args.det_box_type == "quad":
|
||||
img_crop = get_rotate_crop_image(ori_im, tmp_box)
|
||||
else:
|
||||
img_crop = get_minarea_rect_crop(ori_im, tmp_box)
|
||||
img_crop_list.append(img_crop)
|
||||
if self.use_angle_cls and cls:
|
||||
img_crop_list, angle_list, elapse = self.text_classifier(
|
||||
img_crop_list)
|
||||
time_dict['cls'] = elapse
|
||||
logger.debug("cls num : {}, elapsed : {}".format(
|
||||
len(img_crop_list), elapse))
|
||||
|
||||
rec_res, elapse = self.text_recognizer(img_crop_list)
|
||||
time_dict['rec'] = elapse
|
||||
logger.debug("rec_res num : {}, elapsed : {}".format(
|
||||
len(rec_res), elapse))
|
||||
if self.args.save_crop_res:
|
||||
self.draw_crop_rec_res(self.args.crop_res_save_dir, img_crop_list,
|
||||
rec_res)
|
||||
filter_boxes, filter_rec_res = [], []
|
||||
for box, rec_result in zip(dt_boxes, rec_res):
|
||||
text, score = rec_result
|
||||
if score >= self.drop_score:
|
||||
filter_boxes.append(box)
|
||||
filter_rec_res.append(rec_result)
|
||||
end = time.time()
|
||||
time_dict['all'] = end - start
|
||||
return filter_boxes, filter_rec_res, time_dict
|
||||
@@ -0,0 +1,21 @@
from doclayout_yolo import YOLOv10


class DocLayoutYOLOModel(object):
    def __init__(self, weight, device):
        self.model = YOLOv10(weight)
        self.device = device

    def predict(self, image):
        layout_res = []
        doclayout_yolo_res = self.model.predict(image, imgsz=1024, conf=0.25, iou=0.45, verbose=True, device=self.device)[0]
        for xyxy, conf, cla in zip(doclayout_yolo_res.boxes.xyxy.cpu(), doclayout_yolo_res.boxes.conf.cpu(),
                                   doclayout_yolo_res.boxes.cls.cpu()):
            xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
            new_item = {
                'category_id': int(cla.item()),
                'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
                'score': round(float(conf.item()), 3),
            }
            layout_res.append(new_item)
        return layout_res
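A minimal usage sketch of the new DocLayoutYOLOModel wrapper, assuming a doclayout_yolo weight file on disk and an ndarray page image; the file names here are illustrative, not taken from the diff.

import cv2
from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import DocLayoutYOLOModel

layout_model = DocLayoutYOLOModel(weight="models/doclayout_yolo.pt", device="cuda")  # hypothetical path

page_image = cv2.imread("page_0.png")          # rendered PDF page as an ndarray
layout_res = layout_model.predict(page_image)  # list of {'category_id', 'poly', 'score'} dicts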
magic_pdf/model/sub_modules/mfd/__init__.py (new file, 0 lines)
magic_pdf/model/sub_modules/mfd/yolov8/YOLOv8.py (new file, 12 lines)
@@ -0,0 +1,12 @@
from ultralytics import YOLO


class YOLOv8MFDModel(object):
    def __init__(self, weight, device='cpu'):
        self.mfd_model = YOLO(weight)
        self.device = device

    def predict(self, image):
        mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True, device=self.device)[0]
        return mfd_res
magic_pdf/model/sub_modules/mfd/yolov8/__init__.py (new file, 0 lines)
magic_pdf/model/sub_modules/mfr/__init__.py (new file, 0 lines)
magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py (new file, 98 lines)
@@ -0,0 +1,98 @@
import os
import argparse
import re

from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from unimernet.common.config import Config
import unimernet.tasks as tasks
from unimernet.processors import load_processor


class MathDataset(Dataset):
    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # if not pil image, then convert to pil image
        if isinstance(self.image_paths[idx], str):
            raw_image = Image.open(self.image_paths[idx])
        else:
            raw_image = self.image_paths[idx]
        if self.transform:
            image = self.transform(raw_image)
            return image


def latex_rm_whitespace(s: str):
    """Remove unnecessary whitespace from LaTeX code.
    """
    text_reg = r'(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})'
    letter = '[a-zA-Z]'
    noletter = '[\W_^\d]'
    names = [x[0].replace(' ', '') for x in re.findall(text_reg, s)]
    s = re.sub(text_reg, lambda match: str(names.pop(0)), s)
    news = s
    while True:
        s = news
        news = re.sub(r'(?!\\ )(%s)\s+?(%s)' % (noletter, noletter), r'\1\2', s)
        news = re.sub(r'(?!\\ )(%s)\s+?(%s)' % (noletter, letter), r'\1\2', news)
        news = re.sub(r'(%s)\s+?(%s)' % (letter, noletter), r'\1\2', news)
        if news == s:
            break
    return s


class UnimernetModel(object):
    def __init__(self, weight_dir, cfg_path, _device_='cpu'):

        args = argparse.Namespace(cfg_path=cfg_path, options=None)
        cfg = Config(args)
        cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.pth")
        cfg.config.model.model_config.model_name = weight_dir
        cfg.config.model.tokenizer_config.path = weight_dir
        task = tasks.setup_task(cfg)
        self.model = task.build_model(cfg)
        self.device = _device_
        self.model.to(_device_)
        self.model.eval()
        vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
        self.mfr_transform = transforms.Compose([vis_processor, ])

    def predict(self, mfd_res, image):

        formula_list = []
        mf_image_list = []
        for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
            xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
            new_item = {
                'category_id': 13 + int(cla.item()),
                'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
                'score': round(float(conf.item()), 2),
                'latex': '',
            }
            formula_list.append(new_item)
            pil_img = Image.fromarray(image)
            bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
            mf_image_list.append(bbox_img)

        dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
        dataloader = DataLoader(dataset, batch_size=64, num_workers=0)
        mfr_res = []
        for mf_img in dataloader:
            mf_img = mf_img.to(self.device)
            with torch.no_grad():
                output = self.model.generate({'image': mf_img})
            mfr_res.extend(output['pred_str'])
        for res, latex in zip(formula_list, mfr_res):
            res['latex'] = latex_rm_whitespace(latex)
        return formula_list
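A sketch of how the two formula wrappers compose, mirroring the calls CustomPEKModel now makes: YOLOv8MFDModel detects formula regions and its raw result is fed directly into UnimernetModel.predict. The weight locations and page image are illustrative assumptions.

import cv2
from magic_pdf.model.sub_modules.mfd.yolov8.YOLOv8 import YOLOv8MFDModel
from magic_pdf.model.sub_modules.mfr.unimernet.Unimernet import UnimernetModel

mfd_model = YOLOv8MFDModel("models/yolo_mfd.pt", device="cuda")  # hypothetical weight path
mfr_model = UnimernetModel("models/unimernet", "model_configs/UniMERNet/demo.yaml", "cuda")  # hypothetical paths

image = cv2.cvtColor(cv2.imread("page_0.png"), cv2.COLOR_BGR2RGB)  # RGB ndarray page image
mfd_res = mfd_model.predict(image)                 # ultralytics result with formula boxes
formula_list = mfr_model.predict(mfd_res, image)   # [{'category_id', 'poly', 'score', 'latex'}, ...]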
magic_pdf/model/sub_modules/model_init.py (new file, 144 lines)
@@ -0,0 +1,144 @@
|
||||
from loguru import logger
|
||||
|
||||
from magic_pdf.libs.Constants import MODEL_NAME
|
||||
from magic_pdf.model.model_list import AtomicModel
|
||||
from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import DocLayoutYOLOModel
|
||||
from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import Layoutlmv3_Predictor
|
||||
from magic_pdf.model.sub_modules.mfd.yolov8.YOLOv8 import YOLOv8MFDModel
|
||||
|
||||
from magic_pdf.model.sub_modules.mfr.unimernet.Unimernet import UnimernetModel
|
||||
from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import ModifiedPaddleOCR
|
||||
# from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR
|
||||
from magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable import StructTableModel
|
||||
from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import TableMasterPaddleModel
|
||||
from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel
|
||||
|
||||
|
||||
def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
|
||||
if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
|
||||
table_model = StructTableModel(model_path, max_new_tokens=2048, max_time=max_time)
|
||||
elif table_model_type == MODEL_NAME.TABLE_MASTER:
|
||||
config = {
|
||||
"model_dir": model_path,
|
||||
"device": _device_
|
||||
}
|
||||
table_model = TableMasterPaddleModel(config)
|
||||
elif table_model_type == MODEL_NAME.RAPID_TABLE:
|
||||
table_model = RapidTableModel()
|
||||
else:
|
||||
logger.error("table model type not allow")
|
||||
exit(1)
|
||||
|
||||
return table_model
|
||||
|
||||
|
||||
def mfd_model_init(weight, device='cpu'):
|
||||
mfd_model = YOLOv8MFDModel(weight, device)
|
||||
return mfd_model
|
||||
|
||||
|
||||
def mfr_model_init(weight_dir, cfg_path, device='cpu'):
|
||||
mfr_model = UnimernetModel(weight_dir, cfg_path, device)
|
||||
return mfr_model
|
||||
|
||||
|
||||
def layout_model_init(weight, config_file, device):
|
||||
model = Layoutlmv3_Predictor(weight, config_file, device)
|
||||
return model
|
||||
|
||||
|
||||
def doclayout_yolo_model_init(weight, device='cpu'):
|
||||
model = DocLayoutYOLOModel(weight, device)
|
||||
return model
|
||||
|
||||
|
||||
def ocr_model_init(show_log: bool = False,
|
||||
det_db_box_thresh=0.3,
|
||||
lang=None,
|
||||
use_dilation=True,
|
||||
det_db_unclip_ratio=1.8,
|
||||
):
|
||||
if lang is not None:
|
||||
model = ModifiedPaddleOCR(
|
||||
show_log=show_log,
|
||||
det_db_box_thresh=det_db_box_thresh,
|
||||
lang=lang,
|
||||
use_dilation=use_dilation,
|
||||
det_db_unclip_ratio=det_db_unclip_ratio,
|
||||
)
|
||||
else:
|
||||
model = ModifiedPaddleOCR(
|
||||
show_log=show_log,
|
||||
det_db_box_thresh=det_db_box_thresh,
|
||||
use_dilation=use_dilation,
|
||||
det_db_unclip_ratio=det_db_unclip_ratio,
|
||||
# use_angle_cls=True,
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
class AtomModelSingleton:
|
||||
_instance = None
|
||||
_models = {}
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def get_atom_model(self, atom_model_name: str, **kwargs):
|
||||
lang = kwargs.get("lang", None)
|
||||
layout_model_name = kwargs.get("layout_model_name", None)
|
||||
key = (atom_model_name, layout_model_name, lang)
|
||||
if key not in self._models:
|
||||
self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
|
||||
return self._models[key]
|
||||
|
||||
|
||||
def atom_model_init(model_name: str, **kwargs):
|
||||
atom_model = None
|
||||
if model_name == AtomicModel.Layout:
|
||||
if kwargs.get("layout_model_name") == MODEL_NAME.LAYOUTLMv3:
|
||||
atom_model = layout_model_init(
|
||||
kwargs.get("layout_weights"),
|
||||
kwargs.get("layout_config_file"),
|
||||
kwargs.get("device")
|
||||
)
|
||||
elif kwargs.get("layout_model_name") == MODEL_NAME.DocLayout_YOLO:
|
||||
atom_model = doclayout_yolo_model_init(
|
||||
kwargs.get("doclayout_yolo_weights"),
|
||||
kwargs.get("device")
|
||||
)
|
||||
elif model_name == AtomicModel.MFD:
|
||||
atom_model = mfd_model_init(
|
||||
kwargs.get("mfd_weights"),
|
||||
kwargs.get("device")
|
||||
)
|
||||
elif model_name == AtomicModel.MFR:
|
||||
atom_model = mfr_model_init(
|
||||
kwargs.get("mfr_weight_dir"),
|
||||
kwargs.get("mfr_cfg_path"),
|
||||
kwargs.get("device")
|
||||
)
|
||||
elif model_name == AtomicModel.OCR:
|
||||
atom_model = ocr_model_init(
|
||||
kwargs.get("ocr_show_log"),
|
||||
kwargs.get("det_db_box_thresh"),
|
||||
kwargs.get("lang")
|
||||
)
|
||||
elif model_name == AtomicModel.Table:
|
||||
atom_model = table_model_init(
|
||||
kwargs.get("table_model_name"),
|
||||
kwargs.get("table_model_path"),
|
||||
kwargs.get("table_max_time"),
|
||||
kwargs.get("device")
|
||||
)
|
||||
else:
|
||||
logger.error("model name not allow")
|
||||
exit(1)
|
||||
|
||||
if atom_model is None:
|
||||
logger.error("model init failed")
|
||||
exit(1)
|
||||
else:
|
||||
return atom_model
|
||||
magic_pdf/model/sub_modules/model_utils.py (new file, 51 lines)
@@ -0,0 +1,51 @@
import time

import torch
from PIL import Image
from loguru import logger

from magic_pdf.libs.clean_memory import clean_memory


def crop_img(input_res, input_pil_img, crop_paste_x=0, crop_paste_y=0):
    crop_xmin, crop_ymin = int(input_res['poly'][0]), int(input_res['poly'][1])
    crop_xmax, crop_ymax = int(input_res['poly'][4]), int(input_res['poly'][5])
    # Create a white background with an additional width and height of 50
    crop_new_width = crop_xmax - crop_xmin + crop_paste_x * 2
    crop_new_height = crop_ymax - crop_ymin + crop_paste_y * 2
    return_image = Image.new('RGB', (crop_new_width, crop_new_height), 'white')

    # Crop image
    crop_box = (crop_xmin, crop_ymin, crop_xmax, crop_ymax)
    cropped_img = input_pil_img.crop(crop_box)
    return_image.paste(cropped_img, (crop_paste_x, crop_paste_y))
    return_list = [crop_paste_x, crop_paste_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width, crop_new_height]
    return return_image, return_list


# Select regions for OCR / formula regions / table regions
def get_res_list_from_layout_res(layout_res):
    ocr_res_list = []
    table_res_list = []
    single_page_mfdetrec_res = []
    for res in layout_res:
        if int(res['category_id']) in [13, 14]:
            single_page_mfdetrec_res.append({
                "bbox": [int(res['poly'][0]), int(res['poly'][1]),
                         int(res['poly'][4]), int(res['poly'][5])],
            })
        elif int(res['category_id']) in [0, 1, 2, 4, 6, 7]:
            ocr_res_list.append(res)
        elif int(res['category_id']) in [5]:
            table_res_list.append(res)
    return ocr_res_list, table_res_list, single_page_mfdetrec_res


def clean_vram(device, vram_threshold=8):
    if torch.cuda.is_available() and device != 'cpu':
        total_memory = torch.cuda.get_device_properties(device).total_memory / (1024 ** 3)  # convert bytes to GB
        if total_memory <= vram_threshold:
            gc_start = time.time()
            clean_memory()
            gc_time = round(time.time() - gc_start, 2)
            logger.info(f"gc time: {gc_time}")
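A short sketch of how these helpers chain together per page, following the flow in CustomPEKModel above; layout_res and the page image are assumed to come from the layout model and the rendered page.

from PIL import Image
from magic_pdf.model.sub_modules.model_utils import crop_img, get_res_list_from_layout_res, clean_vram

# layout_res: list of {'category_id', 'poly', 'score'} dicts produced by the layout model (assumed)
ocr_res_list, table_res_list, single_page_mfdetrec_res = get_res_list_from_layout_res(layout_res)

pil_img = Image.open("page_0.png")  # illustrative page image
for res in ocr_res_list:
    # Pad the crop by 50 px on each side, as CustomPEKModel does before running OCR
    new_image, useful_list = crop_img(res, pil_img, crop_paste_x=50, crop_paste_y=50)

clean_vram("cuda", vram_threshold=8)  # only collects when total VRAM is at or below the threshold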
magic_pdf/model/sub_modules/ocr/__init__.py (new file, 0 lines)
magic_pdf/model/sub_modules/ocr/paddleocr/ocr_utils.py (new file, 259 lines)
@@ -0,0 +1,259 @@
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
|
||||
from magic_pdf.libs.boxbase import __is_overlaps_y_exceeds_threshold
|
||||
from magic_pdf.pre_proc.ocr_dict_merge import merge_spans_to_line
|
||||
|
||||
|
||||
def bbox_to_points(bbox):
|
||||
""" 将bbox格式转换为四个顶点的数组 """
|
||||
x0, y0, x1, y1 = bbox
|
||||
return np.array([[x0, y0], [x1, y0], [x1, y1], [x0, y1]]).astype('float32')
|
||||
|
||||
|
||||
def points_to_bbox(points):
|
||||
""" 将四个顶点的数组转换为bbox格式 """
|
||||
x0, y0 = points[0]
|
||||
x1, _ = points[1]
|
||||
_, y1 = points[2]
|
||||
return [x0, y0, x1, y1]
|
||||
|
||||
|
||||
def merge_intervals(intervals):
|
||||
# Sort the intervals based on the start value
|
||||
intervals.sort(key=lambda x: x[0])
|
||||
|
||||
merged = []
|
||||
for interval in intervals:
|
||||
# If the list of merged intervals is empty or if the current
|
||||
# interval does not overlap with the previous, simply append it.
|
||||
if not merged or merged[-1][1] < interval[0]:
|
||||
merged.append(interval)
|
||||
else:
|
||||
# Otherwise, there is overlap, so we merge the current and previous intervals.
|
||||
merged[-1][1] = max(merged[-1][1], interval[1])
|
||||
|
||||
return merged
|
||||
|
||||
|
||||
def remove_intervals(original, masks):
|
||||
# Merge all mask intervals
|
||||
merged_masks = merge_intervals(masks)
|
||||
|
||||
result = []
|
||||
original_start, original_end = original
|
||||
|
||||
for mask in merged_masks:
|
||||
mask_start, mask_end = mask
|
||||
|
||||
# If the mask starts after the original range, ignore it
|
||||
if mask_start > original_end:
|
||||
continue
|
||||
|
||||
# If the mask ends before the original range starts, ignore it
|
||||
if mask_end < original_start:
|
||||
continue
|
||||
|
||||
# Remove the masked part from the original range
|
||||
if original_start < mask_start:
|
||||
result.append([original_start, mask_start - 1])
|
||||
|
||||
original_start = max(mask_end + 1, original_start)
|
||||
|
||||
# Add the remaining part of the original range, if any
|
||||
if original_start <= original_end:
|
||||
result.append([original_start, original_end])
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def update_det_boxes(dt_boxes, mfd_res):
|
||||
new_dt_boxes = []
|
||||
for text_box in dt_boxes:
|
||||
text_bbox = points_to_bbox(text_box)
|
||||
masks_list = []
|
||||
for mf_box in mfd_res:
|
||||
mf_bbox = mf_box['bbox']
|
||||
if __is_overlaps_y_exceeds_threshold(text_bbox, mf_bbox):
|
||||
masks_list.append([mf_bbox[0], mf_bbox[2]])
|
||||
text_x_range = [text_bbox[0], text_bbox[2]]
|
||||
text_remove_mask_range = remove_intervals(text_x_range, masks_list)
|
||||
temp_dt_box = []
|
||||
for text_remove_mask in text_remove_mask_range:
|
||||
temp_dt_box.append(bbox_to_points([text_remove_mask[0], text_bbox[1], text_remove_mask[1], text_bbox[3]]))
|
||||
if len(temp_dt_box) > 0:
|
||||
new_dt_boxes.extend(temp_dt_box)
|
||||
return new_dt_boxes
|
||||
|
||||
|
||||
def merge_overlapping_spans(spans):
|
||||
"""
|
||||
Merges overlapping spans on the same line.
|
||||
|
||||
:param spans: A list of span coordinates [(x1, y1, x2, y2), ...]
|
||||
:return: A list of merged spans
|
||||
"""
|
||||
# Return an empty list if the input spans list is empty
|
||||
if not spans:
|
||||
return []
|
||||
|
||||
# Sort spans by their starting x-coordinate
|
||||
spans.sort(key=lambda x: x[0])
|
||||
|
||||
# Initialize the list of merged spans
|
||||
merged = []
|
||||
for span in spans:
|
||||
# Unpack span coordinates
|
||||
x1, y1, x2, y2 = span
|
||||
# If the merged list is empty or there's no horizontal overlap, add the span directly
|
||||
if not merged or merged[-1][2] < x1:
|
||||
merged.append(span)
|
||||
else:
|
||||
# If there is horizontal overlap, merge the current span with the previous one
|
||||
last_span = merged.pop()
|
||||
# Update the merged span's top-left corner to the smaller (x1, y1) and bottom-right to the larger (x2, y2)
|
||||
x1 = min(last_span[0], x1)
|
||||
y1 = min(last_span[1], y1)
|
||||
x2 = max(last_span[2], x2)
|
||||
y2 = max(last_span[3], y2)
|
||||
# Add the merged span back to the list
|
||||
merged.append((x1, y1, x2, y2))
|
||||
|
||||
# Return the list of merged spans
|
||||
return merged
|
||||
|
||||
|
||||
def merge_det_boxes(dt_boxes):
|
||||
"""
|
||||
Merge detection boxes.
|
||||
|
||||
This function takes a list of detected bounding boxes, each represented by four corner points.
|
||||
The goal is to merge these bounding boxes into larger text regions.
|
||||
|
||||
Parameters:
|
||||
dt_boxes (list): A list containing multiple text detection boxes, where each box is defined by four corner points.
|
||||
|
||||
Returns:
|
||||
list: A list containing the merged text regions, where each region is represented by four corner points.
|
||||
"""
|
||||
# Convert the detection boxes into a dictionary format with bounding boxes and type
|
||||
dt_boxes_dict_list = []
|
||||
angle_boxes_list = []
|
||||
for text_box in dt_boxes:
|
||||
text_bbox = points_to_bbox(text_box)
|
||||
if text_bbox[2] <= text_bbox[0] or text_bbox[3] <= text_bbox[1]:
|
||||
angle_boxes_list.append(text_box)
|
||||
continue
|
||||
text_box_dict = {
|
||||
'bbox': text_bbox,
|
||||
'type': 'text',
|
||||
}
|
||||
dt_boxes_dict_list.append(text_box_dict)
|
||||
|
||||
# Merge adjacent text regions into lines
|
||||
lines = merge_spans_to_line(dt_boxes_dict_list)
|
||||
|
||||
# Initialize a new list for storing the merged text regions
|
||||
new_dt_boxes = []
|
||||
for line in lines:
|
||||
line_bbox_list = []
|
||||
for span in line:
|
||||
line_bbox_list.append(span['bbox'])
|
||||
|
||||
# Merge overlapping text regions within the same line
|
||||
merged_spans = merge_overlapping_spans(line_bbox_list)
|
||||
|
||||
# Convert the merged text regions back to point format and add them to the new detection box list
|
||||
for span in merged_spans:
|
||||
new_dt_boxes.append(bbox_to_points(span))
|
||||
|
||||
new_dt_boxes.extend(angle_boxes_list)
|
||||
|
||||
return new_dt_boxes
|
||||
|
||||
|
||||
def get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list):
|
||||
paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
|
||||
# Adjust the coordinates of the formula area
|
||||
adjusted_mfdetrec_res = []
|
||||
for mf_res in single_page_mfdetrec_res:
|
||||
mf_xmin, mf_ymin, mf_xmax, mf_ymax = mf_res["bbox"]
|
||||
# Adjust the coordinates of the formula area to the coordinates relative to the cropping area
|
||||
x0 = mf_xmin - xmin + paste_x
|
||||
y0 = mf_ymin - ymin + paste_y
|
||||
x1 = mf_xmax - xmin + paste_x
|
||||
y1 = mf_ymax - ymin + paste_y
|
||||
# Filter formula blocks outside the graph
|
||||
if any([x1 < 0, y1 < 0]) or any([x0 > new_width, y0 > new_height]):
|
||||
continue
|
||||
else:
|
||||
adjusted_mfdetrec_res.append({
|
||||
"bbox": [x0, y0, x1, y1],
|
||||
})
|
||||
return adjusted_mfdetrec_res
|
||||
|
||||
|
||||
def get_ocr_result_list(ocr_res, useful_list):
    paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
    ocr_result_list = []
    for box_ocr_res in ocr_res:
        p1, p2, p3, p4 = box_ocr_res[0]
        text, score = box_ocr_res[1]
        average_angle_degrees = calculate_angle_degrees(box_ocr_res[0])
        if average_angle_degrees > 0.5:
            # logger.info(f"average_angle_degrees: {average_angle_degrees}, text: {text}")
            # The box tilts more than 0.5 degrees from the x-axis, so re-square its boundaries
            # Compute the geometric center
            x_center = sum(point[0] for point in box_ocr_res[0]) / 4
            y_center = sum(point[1] for point in box_ocr_res[0]) / 4
            new_height = ((p4[1] - p1[1]) + (p3[1] - p2[1])) / 2
            new_width = p3[0] - p1[0]
            p1 = [x_center - new_width / 2, y_center - new_height / 2]
            p2 = [x_center + new_width / 2, y_center - new_height / 2]
            p3 = [x_center + new_width / 2, y_center + new_height / 2]
            p4 = [x_center - new_width / 2, y_center + new_height / 2]

        # Convert the coordinates back to the original coordinate system
        p1 = [p1[0] - paste_x + xmin, p1[1] - paste_y + ymin]
        p2 = [p2[0] - paste_x + xmin, p2[1] - paste_y + ymin]
        p3 = [p3[0] - paste_x + xmin, p3[1] - paste_y + ymin]
        p4 = [p4[0] - paste_x + xmin, p4[1] - paste_y + ymin]

        ocr_result_list.append({
            'category_id': 15,
            'poly': p1 + p2 + p3 + p4,
            'score': float(round(score, 2)),
            'text': text,
        })

    return ocr_result_list

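For clarity, a sketch of the expected input shape (PaddleOCR-style [quad, (text, score)] entries in crop coordinates) and the page-coordinate output; the numbers are illustrative:

useful_list = [50, 50, 200, 300, 800, 500, 700, 300]
ocr_res = [
    [[[10, 20], [110, 20], [110, 40], [10, 40]], ("hello", 0.98)],  # axis-aligned box: no re-squaring
]
print(get_ocr_result_list(ocr_res, useful_list))
# [{'category_id': 15, 'poly': [160, 270, 260, 270, 260, 290, 160, 290],
#   'score': 0.98, 'text': 'hello'}]
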
def calculate_angle_degrees(poly):
    # Endpoints of the two diagonals
    diagonal1 = (poly[0], poly[2])
    diagonal2 = (poly[1], poly[3])

    # Slope of a diagonal
    def slope(p1, p2):
        return (p2[1] - p1[1]) / (p2[0] - p1[0]) if p2[0] != p1[0] else float('inf')

    slope1 = slope(diagonal1[0], diagonal1[1])
    slope2 = slope(diagonal2[0], diagonal2[1])

    # Angles between the diagonals and the x-axis (in radians)
    angle1_radians = math.atan(slope1)
    angle2_radians = math.atan(slope2)

    # Convert radians to degrees
    angle1_degrees = math.degrees(angle1_radians)
    angle2_degrees = math.degrees(angle2_radians)

    # Take the average of the two diagonal angles with the x-axis
    average_angle_degrees = abs((angle1_degrees + angle2_degrees) / 2)
    # logger.info(f"average_angle_degrees: {average_angle_degrees}")
    return average_angle_degrees

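A short worked example of the diagonal-angle heuristic (the quad values are illustrative):

# A text quad tilted slightly clockwise.
poly = [[0, 0], [100, 2], [100, 12], [0, 10]]
# diagonal1 = (0,0)->(100,12): slope 0.12 -> about 6.84 degrees
# diagonal2 = (100,2)->(0,10): slope -0.08 -> about -4.57 degrees
# average of the two angles is about 1.13 degrees, so this box would be re-squared
# by the > 0.5 degree check in get_ocr_result_list above.
print(round(calculate_angle_degrees(poly), 2))  # ~1.13
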
magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_273_mod.py (new file)
@@ -0,0 +1,168 @@
import copy
import time

import cv2
import numpy as np
from paddleocr import PaddleOCR
from paddleocr.paddleocr import check_img, logger
from paddleocr.ppocr.utils.utility import alpha_to_color, binarize_img
from paddleocr.tools.infer.predict_system import sorted_boxes
from paddleocr.tools.infer.utility import get_rotate_crop_image, get_minarea_rect_crop

from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import update_det_boxes, merge_det_boxes


class ModifiedPaddleOCR(PaddleOCR):
    def ocr(self,
            img,
            det=True,
            rec=True,
            cls=True,
            bin=False,
            inv=False,
            alpha_color=(255, 255, 255),
            mfd_res=None,
            ):
        """
        OCR with PaddleOCR

        args:
            img: img for OCR, support ndarray, img_path and list of ndarray
            det: use text detection or not. If False, only rec will be exec. Default is True
            rec: use text recognition or not. If False, only det will be exec. Default is True
            cls: use angle classifier or not. Default is True. If True, the text with rotation of 180 degrees can be recognized. If no text is rotated by 180 degrees, use cls=False to get better performance. Text with rotation of 90 or 270 degrees can be recognized even if cls=False.
            bin: binarize image to black and white. Default is False.
            inv: invert image colors. Default is False.
            alpha_color: set RGB color Tuple for transparent parts replacement. Default is pure white.
        """
        assert isinstance(img, (np.ndarray, list, str, bytes))
        if isinstance(img, list) and det == True:
            logger.error('When input a list of images, det must be false')
            exit(0)
        if cls == True and self.use_angle_cls == False:
            pass
            # logger.warning(
            #     'Since the angle classifier is not initialized, it will not be used during the forward process'
            # )

        img = check_img(img)
        # for infer pdf file
        if isinstance(img, list):
            if self.page_num > len(img) or self.page_num == 0:
                self.page_num = len(img)
            imgs = img[:self.page_num]
        else:
            imgs = [img]

        def preprocess_image(_image):
            _image = alpha_to_color(_image, alpha_color)
            if inv:
                _image = cv2.bitwise_not(_image)
            if bin:
                _image = binarize_img(_image)
            return _image

        if det and rec:
            ocr_res = []
            for idx, img in enumerate(imgs):
                img = preprocess_image(img)
                dt_boxes, rec_res, _ = self.__call__(img, cls, mfd_res=mfd_res)
                if not dt_boxes and not rec_res:
                    ocr_res.append(None)
                    continue
                tmp_res = [[box.tolist(), res]
                           for box, res in zip(dt_boxes, rec_res)]
                ocr_res.append(tmp_res)
            return ocr_res
        elif det and not rec:
            ocr_res = []
            for idx, img in enumerate(imgs):
                img = preprocess_image(img)
                dt_boxes, elapse = self.text_detector(img)
                if not dt_boxes:
                    ocr_res.append(None)
                    continue
                tmp_res = [box.tolist() for box in dt_boxes]
                ocr_res.append(tmp_res)
            return ocr_res
        else:
            ocr_res = []
            cls_res = []
            for idx, img in enumerate(imgs):
                if not isinstance(img, list):
                    img = preprocess_image(img)
                    img = [img]
                if self.use_angle_cls and cls:
                    img, cls_res_tmp, elapse = self.text_classifier(img)
                    if not rec:
                        cls_res.append(cls_res_tmp)
                rec_res, elapse = self.text_recognizer(img)
                ocr_res.append(rec_res)
            if not rec:
                return cls_res
            return ocr_res

    def __call__(self, img, cls=True, mfd_res=None):
        time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0}

        if img is None:
            logger.debug("no valid image provided")
            return None, None, time_dict

        start = time.time()
        ori_im = img.copy()
        dt_boxes, elapse = self.text_detector(img)
        time_dict['det'] = elapse

        if dt_boxes is None:
            logger.debug("no dt_boxes found, elapsed : {}".format(elapse))
            end = time.time()
            time_dict['all'] = end - start
            return None, None, time_dict
        else:
            logger.debug("dt_boxes num : {}, elapsed : {}".format(
                len(dt_boxes), elapse))
        img_crop_list = []

        dt_boxes = sorted_boxes(dt_boxes)

        # TODO: merging is currently done at the bbox level, which handles skewed text lines poorly;
        #       it should be changed to a poly-aware merge.
        # dt_boxes = merge_det_boxes(dt_boxes)

        if mfd_res:
            bef = time.time()
            dt_boxes = update_det_boxes(dt_boxes, mfd_res)
            aft = time.time()
            logger.debug("split text box by formula, new dt_boxes num : {}, elapsed : {}".format(
                len(dt_boxes), aft - bef))

        for bno in range(len(dt_boxes)):
            tmp_box = copy.deepcopy(dt_boxes[bno])
            if self.args.det_box_type == "quad":
                img_crop = get_rotate_crop_image(ori_im, tmp_box)
            else:
                img_crop = get_minarea_rect_crop(ori_im, tmp_box)
            img_crop_list.append(img_crop)
        if self.use_angle_cls and cls:
            img_crop_list, angle_list, elapse = self.text_classifier(
                img_crop_list)
            time_dict['cls'] = elapse
            logger.debug("cls num : {}, elapsed : {}".format(
                len(img_crop_list), elapse))

        rec_res, elapse = self.text_recognizer(img_crop_list)
        time_dict['rec'] = elapse
        logger.debug("rec_res num : {}, elapsed : {}".format(
            len(rec_res), elapse))
        if self.args.save_crop_res:
            self.draw_crop_rec_res(self.args.crop_res_save_dir, img_crop_list,
                                   rec_res)
        filter_boxes, filter_rec_res = [], []
        for box, rec_result in zip(dt_boxes, rec_res):
            text, score = rec_result
            if score >= self.drop_score:
                filter_boxes.append(box)
                filter_rec_res.append(rec_result)
        end = time.time()
        time_dict['all'] = end - start
        return filter_boxes, filter_rec_res, time_dict

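A hedged usage sketch for the subclass above; the image path and constructor arguments are illustrative, and mfd_res would normally be the formula-detection result consumed by update_det_boxes (here it is left as None, i.e. no formula-aware splitting):

import cv2

ocr = ModifiedPaddleOCR(use_angle_cls=True, lang='ch')
img = cv2.imread('page_crop.png')          # illustrative path
result = ocr.ocr(img, mfd_res=None)        # per-image list of [box, (text, score)]
if result and result[0]:
    for box, (text, score) in result[0]:
        print(box, text, score)
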
magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_291_mod.py (new file)
@@ -0,0 +1,213 @@
import copy
import time

import cv2
import numpy as np
from paddleocr import PaddleOCR
from paddleocr.paddleocr import check_img, logger
from paddleocr.ppocr.utils.utility import alpha_to_color, binarize_img
from paddleocr.tools.infer.predict_system import sorted_boxes
from paddleocr.tools.infer.utility import slice_generator, merge_fragmented, get_rotate_crop_image, \
    get_minarea_rect_crop

from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import update_det_boxes


class ModifiedPaddleOCR(PaddleOCR):

    def ocr(
        self,
        img,
        det=True,
        rec=True,
        cls=True,
        bin=False,
        inv=False,
        alpha_color=(255, 255, 255),
        slice={},
        mfd_res=None,
    ):
        """
        OCR with PaddleOCR

        Args:
            img: Image for OCR. It can be an ndarray, img_path, or a list of ndarrays.
            det: Use text detection or not. If False, only text recognition will be executed. Default is True.
            rec: Use text recognition or not. If False, only text detection will be executed. Default is True.
            cls: Use angle classifier or not. Default is True. If True, the text with a rotation of 180 degrees can be recognized. If no text is rotated by 180 degrees, use cls=False to get better performance.
            bin: Binarize image to black and white. Default is False.
            inv: Invert image colors. Default is False.
            alpha_color: Set RGB color Tuple for transparent parts replacement. Default is pure white.
            slice: Use sliding window inference for large images. Both det and rec must be True. Requires int values for slice["horizontal_stride"], slice["vertical_stride"], slice["merge_x_thres"], slice["merge_y_thres"] (See doc/doc_en/slice_en.md). Default is {}.

        Returns:
            If both det and rec are True, returns a list of OCR results for each image. Each OCR result is a list of bounding boxes and recognized text for each detected text region.
            If det is True and rec is False, returns a list of detected bounding boxes for each image.
            If det is False and rec is True, returns a list of recognized text for each image.
            If both det and rec are False, returns a list of angle classification results for each image.

        Raises:
            AssertionError: If the input image is not of type ndarray, list, str, or bytes.
            SystemExit: If det is True and the input is a list of images.

        Note:
            - If the angle classifier is not initialized (use_angle_cls=False), it will not be used during the forward process.
            - For PDF files, if the input is a list of images and the page_num is specified, only the first page_num images will be processed.
            - The preprocess_image function is used to preprocess the input image by applying alpha color replacement, inversion, and binarization if specified.
        """
        assert isinstance(img, (np.ndarray, list, str, bytes))
        if isinstance(img, list) and det == True:
            logger.error("When input a list of images, det must be false")
            exit(0)
        if cls == True and self.use_angle_cls == False:
            logger.warning(
                "Since the angle classifier is not initialized, it will not be used during the forward process"
            )

        img, flag_gif, flag_pdf = check_img(img, alpha_color)
        # for infer pdf file
        if isinstance(img, list) and flag_pdf:
            if self.page_num > len(img) or self.page_num == 0:
                imgs = img
            else:
                imgs = img[: self.page_num]
        else:
            imgs = [img]

        def preprocess_image(_image):
            _image = alpha_to_color(_image, alpha_color)
            if inv:
                _image = cv2.bitwise_not(_image)
            if bin:
                _image = binarize_img(_image)
            return _image

        if det and rec:
            ocr_res = []
            for img in imgs:
                img = preprocess_image(img)
                dt_boxes, rec_res, _ = self.__call__(img, cls, slice, mfd_res=mfd_res)
                if not dt_boxes and not rec_res:
                    ocr_res.append(None)
                    continue
                tmp_res = [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
                ocr_res.append(tmp_res)
            return ocr_res
        elif det and not rec:
            ocr_res = []
            for img in imgs:
                img = preprocess_image(img)
                dt_boxes, elapse = self.text_detector(img)
                if dt_boxes.size == 0:
                    ocr_res.append(None)
                    continue
                tmp_res = [box.tolist() for box in dt_boxes]
                ocr_res.append(tmp_res)
            return ocr_res
        else:
            ocr_res = []
            cls_res = []
            for img in imgs:
                if not isinstance(img, list):
                    img = preprocess_image(img)
                    img = [img]
                if self.use_angle_cls and cls:
                    img, cls_res_tmp, elapse = self.text_classifier(img)
                    if not rec:
                        cls_res.append(cls_res_tmp)
                rec_res, elapse = self.text_recognizer(img)
                ocr_res.append(rec_res)
            if not rec:
                return cls_res
            return ocr_res

    def __call__(self, img, cls=True, slice={}, mfd_res=None):
        time_dict = {"det": 0, "rec": 0, "cls": 0, "all": 0}

        if img is None:
            logger.debug("no valid image provided")
            return None, None, time_dict

        start = time.time()
        ori_im = img.copy()
        if slice:
            slice_gen = slice_generator(
                img,
                horizontal_stride=slice["horizontal_stride"],
                vertical_stride=slice["vertical_stride"],
            )
            elapsed = []
            dt_slice_boxes = []
            for slice_crop, v_start, h_start in slice_gen:
                dt_boxes, elapse = self.text_detector(slice_crop, use_slice=True)
                if dt_boxes.size:
                    dt_boxes[:, :, 0] += h_start
                    dt_boxes[:, :, 1] += v_start
                    dt_slice_boxes.append(dt_boxes)
                    elapsed.append(elapse)
            dt_boxes = np.concatenate(dt_slice_boxes)

            dt_boxes = merge_fragmented(
                boxes=dt_boxes,
                x_threshold=slice["merge_x_thres"],
                y_threshold=slice["merge_y_thres"],
            )
            elapse = sum(elapsed)
        else:
            dt_boxes, elapse = self.text_detector(img)

        time_dict["det"] = elapse

        if dt_boxes is None:
            logger.debug("no dt_boxes found, elapsed : {}".format(elapse))
            end = time.time()
            time_dict["all"] = end - start
            return None, None, time_dict
        else:
            logger.debug(
                "dt_boxes num : {}, elapsed : {}".format(len(dt_boxes), elapse)
            )
        img_crop_list = []

        dt_boxes = sorted_boxes(dt_boxes)

        if mfd_res:
            bef = time.time()
            dt_boxes = update_det_boxes(dt_boxes, mfd_res)
            aft = time.time()
            logger.debug("split text box by formula, new dt_boxes num : {}, elapsed : {}".format(
                len(dt_boxes), aft - bef))

        for bno in range(len(dt_boxes)):
            tmp_box = copy.deepcopy(dt_boxes[bno])
            if self.args.det_box_type == "quad":
                img_crop = get_rotate_crop_image(ori_im, tmp_box)
            else:
                img_crop = get_minarea_rect_crop(ori_im, tmp_box)
            img_crop_list.append(img_crop)
        if self.use_angle_cls and cls:
            img_crop_list, angle_list, elapse = self.text_classifier(img_crop_list)
            time_dict["cls"] = elapse
            logger.debug(
                "cls num : {}, elapsed : {}".format(len(img_crop_list), elapse)
            )
        if len(img_crop_list) > 1000:
            logger.debug(
                f"rec crops num: {len(img_crop_list)}, time and memory cost may be large."
            )

        rec_res, elapse = self.text_recognizer(img_crop_list)
        time_dict["rec"] = elapse
        logger.debug("rec_res num : {}, elapsed : {}".format(len(rec_res), elapse))
        if self.args.save_crop_res:
            self.draw_crop_rec_res(self.args.crop_res_save_dir, img_crop_list, rec_res)
        filter_boxes, filter_rec_res = [], []
        for box, rec_result in zip(dt_boxes, rec_res):
            text, score = rec_result[0], rec_result[1]
            if score >= self.drop_score:
                filter_boxes.append(box)
                filter_rec_res.append(rec_result)
        end = time.time()
        time_dict["all"] = end - start
        return filter_boxes, filter_rec_res, time_dict

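An illustrative sketch of the sliding-window slice option documented in this 2.9.x-oriented variant; the path, stride, and merge-threshold values below are example numbers, not recommendations:

import cv2

ocr = ModifiedPaddleOCR(use_angle_cls=True, lang='ch')
img = cv2.imread('large_page.png')  # illustrative path
slice_cfg = {
    "horizontal_stride": 300,
    "vertical_stride": 500,
    "merge_x_thres": 50,
    "merge_y_thres": 35,
}
# Both det and rec must be True when slice is used.
result = ocr.ocr(img, det=True, rec=True, slice=slice_cfg)
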
magic_pdf/model/sub_modules/table/__init__.py (new, empty file)

magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py (new file)
@@ -0,0 +1,14 @@
import numpy as np
from rapid_table import RapidTable
from rapidocr_paddle import RapidOCR


class RapidTableModel(object):
    def __init__(self):
        self.table_model = RapidTable()
        self.ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)

    def predict(self, image):
        ocr_result, _ = self.ocr_engine(np.asarray(image))
        html_code, table_cell_bboxes, elapse = self.table_model(np.asarray(image), ocr_result)
        return html_code, table_cell_bboxes, elapse

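A minimal usage sketch for the wrapper above (the image path is illustrative); predict() runs RapidOCR on the crop and feeds the result to rapid_table, returning the table as HTML plus cell bounding boxes and the elapsed time:

from PIL import Image

table_model = RapidTableModel()
image = Image.open('table_crop.png')  # illustrative path
html_code, table_cell_bboxes, elapse = table_model.predict(image)
print(html_code)
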
@@ -1,8 +1,8 @@
import re

import torch
from struct_eqtable import build_model

from magic_pdf.model.sub_modules.table.table_utils import minify_html


class StructTableModel:
    def __init__(self, model_path, max_new_tokens=1024, max_time=60):
@@ -31,15 +31,7 @@ class StructTableModel:
        )

        if output_format == "html":
            results = [self.minify_html(html) for html in results]
            results = [minify_html(html) for html in results]

        return results

    def minify_html(self, html):
        # Remove redundant whitespace
        html = re.sub(r'\s+', ' ', html)
        # Remove whitespace around '>'
        html = re.sub(r'\s*>\s*', '>', html)
        # Remove whitespace around '<'
        html = re.sub(r'\s*<\s*', '<', html)
        return html.strip()

magic_pdf/model/sub_modules/table/table_utils.py (new file)
@@ -0,0 +1,11 @@
import re


def minify_html(html):
    # Remove redundant whitespace
    html = re.sub(r'\s+', ' ', html)
    # Remove whitespace around '>'
    html = re.sub(r'\s*>\s*', '>', html)
    # Remove whitespace around '<'
    html = re.sub(r'\s*<\s*', '<', html)
    return html.strip()

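A quick example of what the helper does to whitespace (the input snippet is illustrative):

raw = """<table>
  <tr>
    <td> 1 </td>
  </tr>
</table>"""
print(minify_html(raw))
# <table><tr><td>1</td></tr></table>
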
@@ -7,7 +7,7 @@ from PIL import Image
import numpy as np


class ppTableModel(object):
class TableMasterPaddleModel(object):
    """
    This class is responsible for converting an image of a table into HTML format using a pre-trained model.

@@ -164,8 +164,8 @@ class ModelSingleton:


def do_predict(boxes: List[List[int]], model) -> List[int]:
    from magic_pdf.model.v3.helpers import (boxes2inputs, parse_logits,
                                            prepare_inputs)
    from magic_pdf.model.sub_modules.reading_oreder.layoutreader.helpers import (boxes2inputs, parse_logits,
                                                                                  prepare_inputs)

    inputs = boxes2inputs(boxes)
    inputs = prepare_inputs(inputs, model)
@@ -206,7 +206,7 @@ def cal_block_index(fix_blocks, sorted_bboxes):
        del block['real_lines']

    import numpy as np
    from magic_pdf.model.v3.xycut import recursive_xy_cut
    from magic_pdf.model.sub_modules.reading_oreder.layoutreader.xycut import recursive_xy_cut

    random_boxes = np.array(block_bboxes)
    np.random.shuffle(random_boxes)

setup.py
@@ -49,6 +49,7 @@ if __name__ == '__main__':
"doclayout_yolo==0.0.2", # doclayout_yolo
|
||||
"rapidocr-paddle", # rapidocr-paddle
|
||||
"rapid_table", # rapid_table
|
||||
"PyYAML", # yaml
|
||||
"detectron2"
|
||||
],
|
||||
},
|
||||
|
||||
@@ -2,7 +2,7 @@ import unittest
from PIL import Image
from lxml import etree

from magic_pdf.model.ppTableModel import ppTableModel
from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import TableMasterPaddleModel


class TestppTableModel(unittest.TestCase):
@@ -11,7 +11,7 @@ class TestppTableModel(unittest.TestCase):
        # Point this at the local TableMaster model directory
        config = {"device": "cuda",
                  "model_dir": "/home/quyuan/.cache/modelscope/hub/opendatalab/PDF-Extract-Kit/models/TabRec/TableMaster"}
        table_model = ppTableModel(config)
        table_model = TableMasterPaddleModel(config)
        res = table_model.img2html(img)
        # Verify that the generated HTML matches expectations
        parser = etree.HTMLParser()