mirror of
https://github.com/opendatalab/MinerU.git
synced 2026-03-27 11:08:32 +07:00
Refactor: add pp_doclayoutv2 as layout model
This commit is contained in:
@@ -5,6 +5,7 @@ from loguru import logger
|
||||
|
||||
from .model_list import AtomicModel
|
||||
from ...model.layout.doclayoutyolo import DocLayoutYOLOModel
|
||||
from ...model.layout.pp_doclayoutv2 import PPDocLayoutV2LayoutModel
|
||||
from ...model.mfd.yolo_v8 import YOLOv8MFDModel
|
||||
from ...model.mfr.unimernet.Unimernet import UnimernetModel
|
||||
from ...model.mfr.pp_formulanet_plus_m.predict_formula import FormulaRecognizer
|
||||
@@ -28,6 +29,27 @@ else:
|
||||
MFR_MODEL = "unimernet_small"
|
||||
|
||||
|
||||
def get_layout_model_name() -> str:
    """Return the normalized name of the layout model selected by the environment.

    Reads the ``MINERU_LAYOUT_MODEL`` environment variable, strips surrounding
    whitespace and lower-cases it.

    Returns:
        The normalized model name, or ``"doclayout_yolo"`` when the variable
        is unset, empty, or contains only whitespace.
    """
    # A variable that is exported but blank (e.g. ``MINERU_LAYOUT_MODEL=``)
    # bypasses the os.getenv default, so the fallback is applied after
    # normalization rather than through the default argument.
    configured = os.getenv("MINERU_LAYOUT_MODEL", "").strip().lower()
    return configured or "doclayout_yolo"
|
||||
|
||||
|
||||
def get_layout_model_weight(layout_model_name: str) -> str:
    """Resolve the weight location for the given layout model.

    Resolution order:

    1. ``MINERU_LAYOUT_MODEL_PATH``, when set to a non-blank value, is
       returned as-is for any model name.
    2. For ``"doclayout_yolo"``: ``ModelPath.doclayout_yolo`` joined onto the
       root returned by ``auto_download_and_get_model_root_path``.
    3. For ``"pp_doclayout_v2"``: ``MINERU_PP_DOCLAYOUT_V2_MODEL`` when set to
       a non-blank value, otherwise ``ModelPath.pp_doclayout_v2``.

    Args:
        layout_model_name: Layout model name; compared verbatim against the
            supported names.

    Returns:
        The weight path or model identifier to load.

    Raises:
        ValueError: If no override path is set and ``layout_model_name`` is
            not a supported layout model.
    """
    # Blank or whitespace-only values are treated as "not configured" so an
    # exported-but-empty variable cannot produce an empty weight path.
    override_path = (os.getenv("MINERU_LAYOUT_MODEL_PATH") or "").strip()
    if override_path:
        return override_path

    if layout_model_name == "doclayout_yolo":
        model_root = auto_download_and_get_model_root_path(ModelPath.doclayout_yolo)
        return os.path.join(model_root, ModelPath.doclayout_yolo)

    if layout_model_name == "pp_doclayout_v2":
        configured = (os.getenv("MINERU_PP_DOCLAYOUT_V2_MODEL") or "").strip()
        return configured or ModelPath.pp_doclayout_v2

    raise ValueError(f"Unsupported layout model: {layout_model_name}")
|
||||
|
||||
|
||||
def img_orientation_cls_model_init():
|
||||
atom_model_manager = AtomModelSingleton()
|
||||
ocr_engine = atom_model_manager.get_atom_model(
|
||||
@@ -95,6 +117,13 @@ def doclayout_yolo_model_init(weight, device='cpu'):
|
||||
model = DocLayoutYOLOModel(weight, device)
|
||||
return model
|
||||
|
||||
|
||||
def pp_doclayout_v2_model_init(weight, device='cpu'):
    """Build a PP-DocLayoutV2 layout model for the given weights and device.

    Args:
        weight: Weight location forwarded to ``PPDocLayoutV2LayoutModel``.
        device: Target device. A value whose string form starts with
            ``'npu'`` is converted to a ``torch.device``; anything else is
            forwarded unchanged.

    Returns:
        The constructed ``PPDocLayoutV2LayoutModel`` instance.
    """
    targets_npu = str(device).startswith('npu')
    resolved_device = torch.device(device) if targets_npu else device
    return PPDocLayoutV2LayoutModel(weight, resolved_device)
|
||||
|
||||
def ocr_model_init(det_db_box_thresh=0.3,
|
||||
lang=None,
|
||||
det_db_unclip_ratio=1.8,
|
||||
@@ -144,6 +173,13 @@ class AtomModelSingleton:
|
||||
kwargs.get('det_db_unclip_ratio', 1.8),
|
||||
kwargs.get('enable_merge_det_boxes', True)
|
||||
)
|
||||
elif atom_model_name == AtomicModel.Layout:
|
||||
key = (
|
||||
atom_model_name,
|
||||
kwargs.get('layout_model_name', 'doclayout_yolo'),
|
||||
kwargs.get('layout_model_weight'),
|
||||
kwargs.get('device'),
|
||||
)
|
||||
else:
|
||||
key = atom_model_name
|
||||
|
||||
@@ -154,10 +190,21 @@ class AtomModelSingleton:
|
||||
def atom_model_init(model_name: str, **kwargs):
|
||||
atom_model = None
|
||||
if model_name == AtomicModel.Layout:
|
||||
atom_model = doclayout_yolo_model_init(
|
||||
kwargs.get('doclayout_yolo_weights'),
|
||||
kwargs.get('device')
|
||||
)
|
||||
layout_model_name = kwargs.get('layout_model_name', 'doclayout_yolo')
|
||||
layout_model_weight = kwargs.get('layout_model_weight')
|
||||
if layout_model_name == 'doclayout_yolo':
|
||||
atom_model = doclayout_yolo_model_init(
|
||||
layout_model_weight or kwargs.get('doclayout_yolo_weights'),
|
||||
kwargs.get('device')
|
||||
)
|
||||
elif layout_model_name == 'pp_doclayout_v2':
|
||||
atom_model = pp_doclayout_v2_model_init(
|
||||
layout_model_weight or kwargs.get('pp_doclayout_v2_weights'),
|
||||
kwargs.get('device')
|
||||
)
|
||||
else:
|
||||
logger.error(f'layout model name not allow: {layout_model_name}')
|
||||
exit(1)
|
||||
elif model_name == AtomicModel.MFD:
|
||||
atom_model = mfd_model_init(
|
||||
kwargs.get('mfd_weights'),
|
||||
@@ -210,6 +257,8 @@ class MineruPipelineModel:
|
||||
'DocAnalysis init, this may take some times......'
|
||||
)
|
||||
atom_model_manager = AtomModelSingleton()
|
||||
layout_model_name = get_layout_model_name()
|
||||
layout_model_weight = get_layout_model_weight(layout_model_name)
|
||||
|
||||
if self.apply_formula:
|
||||
# Initialize the formula detection model
|
||||
@@ -237,11 +286,11 @@ class MineruPipelineModel:
|
||||
)
|
||||
|
||||
# Initialize the layout model
|
||||
logger.info(f'Using layout model: {layout_model_name}')
|
||||
self.layout_model = atom_model_manager.get_atom_model(
|
||||
atom_model_name=AtomicModel.Layout,
|
||||
doclayout_yolo_weights=str(
|
||||
os.path.join(auto_download_and_get_model_root_path(ModelPath.doclayout_yolo), ModelPath.doclayout_yolo)
|
||||
),
|
||||
layout_model_name=layout_model_name,
|
||||
layout_model_weight=layout_model_weight,
|
||||
device=self.device,
|
||||
)
|
||||
# Initialize OCR
|
||||
@@ -369,4 +418,4 @@ class MineruHybridModel:
|
||||
atom_model_name=AtomicModel.MFR,
|
||||
mfr_weight_dir=str(os.path.join(auto_download_and_get_model_root_path(mfr_model_path), mfr_model_path)),
|
||||
device=self.device,
|
||||
)
|
||||
)
|
||||
|
||||
1159
mineru/model/layout/pp_doclayoutv2.py
Normal file
1159
mineru/model/layout/pp_doclayoutv2.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -100,6 +100,7 @@ class ModelPath:
|
||||
pipeline_root_modelscope = "OpenDataLab/PDF-Extract-Kit-1.0"
|
||||
pipeline_root_hf = "opendatalab/PDF-Extract-Kit-1.0"
|
||||
doclayout_yolo = "models/Layout/YOLO/doclayout_yolo_docstructbench_imgsz1280_2501.pt"
|
||||
pp_doclayout_v2 = "models/Layout/PP-DocLayoutV2"
|
||||
yolo_v8_mfd = "models/MFD/YOLO/yolo_v8_ft.pt"
|
||||
unimernet_small = "models/MFR/unimernet_hf_small_2503"
|
||||
pp_formulanet_plus_m = "models/MFR/pp_formulanet_plus_m"
|
||||
|
||||
@@ -85,6 +85,7 @@ pipeline = [
|
||||
"torch>=2.6.0,<3",
|
||||
"torchvision",
|
||||
"transformers>=4.49.0,!=4.51.0,<5.0.0",
|
||||
"safetensors>=0.4.0,<1",
|
||||
"onnxruntime>1.17.0",
|
||||
]
|
||||
api = [
|
||||
@@ -169,4 +170,4 @@ exclude_also = [
|
||||
'if TYPE_CHECKING:',
|
||||
'class .*\bProtocol\):',
|
||||
'@(abc\.)?abstractmethod',
|
||||
]
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user