mirror of
https://github.com/opendatalab/MinerU.git
synced 2026-04-04 23:18:37 +07:00
- Remove concurrency limit logic from app.py - Update model initialization process in various modules - Remove unused VRAM check for concurrency limit - Refactor OCR model initialization in pdf_extract_kit.py - Update txt_spans_extract_v2 function to use lang parameter instead of ocr_model
160 lines
5.3 KiB
Python
160 lines
5.3 KiB
Python
from loguru import logger
|
|
|
|
from magic_pdf.config.constants import MODEL_NAME
|
|
from magic_pdf.model.model_list import AtomicModel
|
|
from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import \
|
|
DocLayoutYOLOModel
|
|
from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import \
|
|
Layoutlmv3_Predictor
|
|
from magic_pdf.model.sub_modules.mfd.yolov8.YOLOv8 import YOLOv8MFDModel
|
|
from magic_pdf.model.sub_modules.mfr.unimernet.Unimernet import UnimernetModel
|
|
from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import \
|
|
ModifiedPaddleOCR
|
|
from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import \
|
|
RapidTableModel
|
|
# from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR
|
|
from magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable import \
|
|
StructTableModel
|
|
from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import \
|
|
TableMasterPaddleModel
|
|
|
|
|
|
def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
    """Build the table model selected by ``table_model_type``.

    Args:
        table_model_type: One of the ``MODEL_NAME`` table constants
            (``STRUCT_EQTABLE``, ``TABLE_MASTER`` or ``RAPID_TABLE``).
        model_path: Passed to ``StructTableModel`` as its first argument, or
            used as ``model_dir`` in the ``TableMasterPaddleModel`` config.
            Not used for ``RAPID_TABLE``.
        max_time: Forwarded to ``StructTableModel`` as ``max_time``; only used
            for ``STRUCT_EQTABLE``.
        _device_: Stored under ``'device'`` in the ``TableMasterPaddleModel``
            config; only used for ``TABLE_MASTER``.

    Returns:
        The constructed table model instance.

    Raises:
        SystemExit: With status 1 when ``table_model_type`` matches none of
            the supported constants.
    """
    if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
        table_model = StructTableModel(model_path, max_new_tokens=2048, max_time=max_time)
    elif table_model_type == MODEL_NAME.TABLE_MASTER:
        config = {
            'model_dir': model_path,
            'device': _device_,
        }
        table_model = TableMasterPaddleModel(config)
    elif table_model_type == MODEL_NAME.RAPID_TABLE:
        table_model = RapidTableModel()
    else:
        logger.error('table model type not allow')
        # Raise SystemExit directly instead of calling the `exit()` helper:
        # that helper is injected by the `site` module for interactive use
        # and is not guaranteed to exist (e.g. `python -S`, frozen builds).
        raise SystemExit(1)

    return table_model
|
|
|
|
|
|
def mfd_model_init(weight, device='cpu'):
    """Construct and return a ``YOLOv8MFDModel``.

    Args:
        weight: First positional argument handed to ``YOLOv8MFDModel``.
        device: Second positional argument handed to ``YOLOv8MFDModel``;
            defaults to ``'cpu'``.

    Returns:
        The new ``YOLOv8MFDModel`` instance.
    """
    return YOLOv8MFDModel(weight, device)
|
|
|
|
|
|
def mfr_model_init(weight_dir, cfg_path, device='cpu'):
    """Construct and return a ``UnimernetModel``.

    Args:
        weight_dir: First positional argument handed to ``UnimernetModel``.
        cfg_path: Second positional argument handed to ``UnimernetModel``.
        device: Third positional argument handed to ``UnimernetModel``;
            defaults to ``'cpu'``.

    Returns:
        The new ``UnimernetModel`` instance.
    """
    return UnimernetModel(weight_dir, cfg_path, device)
|
|
|
|
|
|
def layout_model_init(weight, config_file, device):
    """Construct and return a ``Layoutlmv3_Predictor``.

    Args:
        weight: First positional argument handed to ``Layoutlmv3_Predictor``.
        config_file: Second positional argument handed to
            ``Layoutlmv3_Predictor``.
        device: Third positional argument handed to ``Layoutlmv3_Predictor``.
            Unlike the sibling initialisers, this one has no default.

    Returns:
        The new ``Layoutlmv3_Predictor`` instance.
    """
    return Layoutlmv3_Predictor(weight, config_file, device)
|
|
|
|
|
|
def doclayout_yolo_model_init(weight, device='cpu'):
    """Construct and return a ``DocLayoutYOLOModel``.

    Args:
        weight: First positional argument handed to ``DocLayoutYOLOModel``.
        device: Second positional argument handed to ``DocLayoutYOLOModel``;
            defaults to ``'cpu'``.

    Returns:
        The new ``DocLayoutYOLOModel`` instance.
    """
    return DocLayoutYOLOModel(weight, device)
|
|
|
|
|
|
def ocr_model_init(show_log: bool = False,
                   det_db_box_thresh=0.3,
                   lang=None,
                   use_dilation=True,
                   det_db_unclip_ratio=1.8,
                   ):
    """Construct and return a ``ModifiedPaddleOCR`` instance.

    All arguments are forwarded to ``ModifiedPaddleOCR`` as keyword arguments
    of the same name. ``lang`` is only forwarded when it is neither ``None``
    nor the empty string; otherwise the constructor is called without a
    ``lang`` keyword at all.

    Args:
        show_log: Forwarded as ``show_log``.
        det_db_box_thresh: Forwarded as ``det_db_box_thresh``.
        lang: Forwarded as ``lang`` when set (see above).
        use_dilation: Forwarded as ``use_dilation``.
        det_db_unclip_ratio: Forwarded as ``det_db_unclip_ratio``.

    Returns:
        The new ``ModifiedPaddleOCR`` instance.
    """
    # Assemble the keyword arguments in the order they are passed on:
    # show_log, det_db_box_thresh, [lang], use_dilation, det_db_unclip_ratio.
    ocr_options = {
        'show_log': show_log,
        'det_db_box_thresh': det_db_box_thresh,
    }
    if lang is not None and lang != '':
        ocr_options['lang'] = lang
    ocr_options['use_dilation'] = use_dilation
    ocr_options['det_db_unclip_ratio'] = det_db_unclip_ratio
    # use_angle_cls is not passed in either case.
    return ModifiedPaddleOCR(**ocr_options)
|
|
|
|
|
|
class AtomModelSingleton:
    """Singleton that builds atomic models on first request and caches them.

    Every call to ``AtomModelSingleton()`` returns the same object, and
    models are stored in a class-level dict so each distinct cache key is
    initialised only once.
    """

    # The one shared instance, created on the first construction.
    _instance = None
    # Cache of initialised models, keyed as described in get_atom_model.
    _models = {}

    def __new__(cls, *args, **kwargs):
        """Return the shared instance, creating it on first use."""
        instance = cls._instance
        if instance is None:
            instance = super().__new__(cls)
            cls._instance = instance
        return instance

    def get_atom_model(self, atom_model_name: str, **kwargs):
        """Return the cached model for ``atom_model_name``, creating it if needed.

        The cache key is ``(atom_model_name, <variant>)`` for OCR (variant is
        ``lang``), Layout (``layout_model_name``) and Table
        (``table_model_name``); for every other model it is just
        ``atom_model_name``. Other keyword arguments do not affect the key.

        Args:
            atom_model_name: Compared against the ``AtomicModel`` constants.
            **kwargs: Forwarded unchanged to ``atom_model_init`` on a cache
                miss; ``lang``, ``layout_model_name`` and
                ``table_model_name`` are also read for the cache key.

        Returns:
            The cached or newly initialised model.
        """
        ocr_lang = kwargs.get('lang', None)
        layout_variant = kwargs.get('layout_model_name', None)
        table_variant = kwargs.get('table_model_name', None)

        # Default key: the bare model name; refined below for the model
        # kinds that come in several variants.
        cache_key = atom_model_name
        if atom_model_name in [AtomicModel.OCR]:
            cache_key = (atom_model_name, ocr_lang)
        elif atom_model_name in [AtomicModel.Layout]:
            cache_key = (atom_model_name, layout_variant)
        elif atom_model_name in [AtomicModel.Table]:
            cache_key = (atom_model_name, table_variant)

        model_cache = self._models
        if cache_key not in model_cache:
            model_cache[cache_key] = atom_model_init(model_name=atom_model_name, **kwargs)
        return model_cache[cache_key]
|
|
|
|
def atom_model_init(model_name: str, **kwargs):
    """Create the atomic model identified by ``model_name``.

    Dispatches on ``model_name`` (compared against the ``AtomicModel``
    constants) to the matching ``*_model_init`` helper, taking that helper's
    arguments from ``kwargs``.

    Keyword arguments read, per model:
        Layout: ``layout_model_name``, then either ``layout_weights``,
            ``layout_config_file`` and ``device`` (``MODEL_NAME.LAYOUTLMv3``)
            or ``doclayout_yolo_weights`` and ``device``
            (``MODEL_NAME.DocLayout_YOLO``).
        MFD: ``mfd_weights``, ``device``.
        MFR: ``mfr_weight_dir``, ``mfr_cfg_path``, ``device``.
        OCR: ``ocr_show_log``, ``det_db_box_thresh``, ``lang``.
        Table: ``table_model_name``, ``table_model_path``,
            ``table_max_time``, ``device``.

    NOTE(review): values are fetched with ``kwargs.get`` and passed
    positionally, so a missing key is forwarded as ``None`` and takes the
    place of the helper's own default (e.g. ``device='cpu'``). Confirm that
    callers always supply the keys they rely on.

    Args:
        model_name: Which atomic model to build.
        **kwargs: Model-specific settings as listed above.

    Returns:
        The constructed model.

    Raises:
        SystemExit: With status 1 when ``model_name`` is not a supported
            ``AtomicModel`` value, or when no model was produced (for
            example an unrecognised ``layout_model_name``).
    """
    atom_model = None
    if model_name == AtomicModel.Layout:
        if kwargs.get('layout_model_name') == MODEL_NAME.LAYOUTLMv3:
            atom_model = layout_model_init(
                kwargs.get('layout_weights'),
                kwargs.get('layout_config_file'),
                kwargs.get('device')
            )
        elif kwargs.get('layout_model_name') == MODEL_NAME.DocLayout_YOLO:
            atom_model = doclayout_yolo_model_init(
                kwargs.get('doclayout_yolo_weights'),
                kwargs.get('device')
            )
        # Any other layout_model_name leaves atom_model as None and is
        # reported by the final check below.
    elif model_name == AtomicModel.MFD:
        atom_model = mfd_model_init(
            kwargs.get('mfd_weights'),
            kwargs.get('device')
        )
    elif model_name == AtomicModel.MFR:
        atom_model = mfr_model_init(
            kwargs.get('mfr_weight_dir'),
            kwargs.get('mfr_cfg_path'),
            kwargs.get('device')
        )
    elif model_name == AtomicModel.OCR:
        atom_model = ocr_model_init(
            kwargs.get('ocr_show_log'),
            kwargs.get('det_db_box_thresh'),
            kwargs.get('lang'),
        )
    elif model_name == AtomicModel.Table:
        atom_model = table_model_init(
            kwargs.get('table_model_name'),
            kwargs.get('table_model_path'),
            kwargs.get('table_max_time'),
            kwargs.get('device')
        )
    else:
        logger.error('model name not allow')
        # Raise SystemExit directly instead of calling the `exit()` helper:
        # that helper is injected by the `site` module for interactive use
        # and is not guaranteed to exist (e.g. `python -S`, frozen builds).
        raise SystemExit(1)

    if atom_model is None:
        logger.error('model init failed')
        raise SystemExit(1)
    return atom_model
|