Refactor: add pp_doclayoutv2 as layout model

This commit is contained in:
myhloli
2026-03-09 16:12:33 +08:00
parent 323e0092e0
commit 65fb7495fd
4 changed files with 1219 additions and 9 deletions

View File

@@ -5,6 +5,7 @@ from loguru import logger
from .model_list import AtomicModel
from ...model.layout.doclayoutyolo import DocLayoutYOLOModel
from ...model.layout.pp_doclayoutv2 import PPDocLayoutV2LayoutModel
from ...model.mfd.yolo_v8 import YOLOv8MFDModel
from ...model.mfr.unimernet.Unimernet import UnimernetModel
from ...model.mfr.pp_formulanet_plus_m.predict_formula import FormulaRecognizer
@@ -28,6 +29,27 @@ else:
MFR_MODEL = "unimernet_small"
def get_layout_model_name() -> str:
    """Return the normalized name of the layout model selected via the environment.

    Reads ``MINERU_LAYOUT_MODEL``, strips surrounding whitespace and
    lower-cases the value.

    Returns:
        The configured layout model name, or ``"doclayout_yolo"`` when the
        variable is unset, empty, or contains only whitespace.
    """
    default_name = "doclayout_yolo"
    configured = os.getenv("MINERU_LAYOUT_MODEL", default_name).strip().lower()
    # os.getenv only applies its default when the variable is missing; an
    # exported-but-blank variable would otherwise produce "" here, which no
    # layout backend recognizes. Treat blank the same as unset.
    return configured or default_name
def get_layout_model_weight(layout_model_name: str) -> str:
    """Resolve the weight location to use for the given layout model.

    Resolution order:

    1. ``MINERU_LAYOUT_MODEL_PATH`` -- when set to a non-empty value it is
       returned as-is, whatever ``layout_model_name`` is.
    2. ``"doclayout_yolo"`` -- the model root reported by
       ``auto_download_and_get_model_root_path`` joined with
       ``ModelPath.doclayout_yolo``.
    3. ``"pp_doclayout_v2"`` -- ``MINERU_PP_DOCLAYOUT_V2_MODEL`` when set to
       a non-empty value, otherwise ``ModelPath.pp_doclayout_v2``.

    Args:
        layout_model_name: Layout model name to resolve weights for.

    Returns:
        The weight path/identifier as a string.

    Raises:
        ValueError: If no override path is set and ``layout_model_name`` is
            not one of the supported names.
    """
    override_path = os.getenv("MINERU_LAYOUT_MODEL_PATH")
    if override_path:
        return override_path

    if layout_model_name == "doclayout_yolo":
        return str(
            os.path.join(
                auto_download_and_get_model_root_path(ModelPath.doclayout_yolo),
                ModelPath.doclayout_yolo,
            )
        )

    if layout_model_name == "pp_doclayout_v2":
        # Use `or` instead of getenv's default so that a variable that is set
        # but empty falls back to the built-in location, matching how an empty
        # MINERU_LAYOUT_MODEL_PATH is treated above.
        # NOTE(review): unlike the doclayout_yolo branch, this default is not
        # joined with a downloaded model root -- presumably the model class
        # resolves it itself; confirm against PPDocLayoutV2LayoutModel.
        return os.getenv("MINERU_PP_DOCLAYOUT_V2_MODEL") or ModelPath.pp_doclayout_v2

    raise ValueError(f"Unsupported layout model: {layout_model_name}")
def img_orientation_cls_model_init():
atom_model_manager = AtomModelSingleton()
ocr_engine = atom_model_manager.get_atom_model(
@@ -95,6 +117,13 @@ def doclayout_yolo_model_init(weight, device='cpu'):
model = DocLayoutYOLOModel(weight, device)
return model
def pp_doclayout_v2_model_init(weight, device='cpu'):
    """Build a ``PPDocLayoutV2LayoutModel`` from the given weights.

    Args:
        weight: Weight location forwarded unchanged to the model constructor.
        device: Device specifier. When its string form starts with ``'npu'``
            it is converted to a ``torch.device`` before being handed to the
            model; any other value is passed through untouched.

    Returns:
        The constructed ``PPDocLayoutV2LayoutModel`` instance.
    """
    targets_npu = str(device).startswith('npu')
    resolved_device = torch.device(device) if targets_npu else device
    return PPDocLayoutV2LayoutModel(weight, resolved_device)
def ocr_model_init(det_db_box_thresh=0.3,
lang=None,
det_db_unclip_ratio=1.8,
@@ -144,6 +173,13 @@ class AtomModelSingleton:
kwargs.get('det_db_unclip_ratio', 1.8),
kwargs.get('enable_merge_det_boxes', True)
)
elif atom_model_name == AtomicModel.Layout:
key = (
atom_model_name,
kwargs.get('layout_model_name', 'doclayout_yolo'),
kwargs.get('layout_model_weight'),
kwargs.get('device'),
)
else:
key = atom_model_name
@@ -154,10 +190,21 @@ class AtomModelSingleton:
def atom_model_init(model_name: str, **kwargs):
atom_model = None
if model_name == AtomicModel.Layout:
atom_model = doclayout_yolo_model_init(
kwargs.get('doclayout_yolo_weights'),
kwargs.get('device')
)
layout_model_name = kwargs.get('layout_model_name', 'doclayout_yolo')
layout_model_weight = kwargs.get('layout_model_weight')
if layout_model_name == 'doclayout_yolo':
atom_model = doclayout_yolo_model_init(
layout_model_weight or kwargs.get('doclayout_yolo_weights'),
kwargs.get('device')
)
elif layout_model_name == 'pp_doclayout_v2':
atom_model = pp_doclayout_v2_model_init(
layout_model_weight or kwargs.get('pp_doclayout_v2_weights'),
kwargs.get('device')
)
else:
logger.error(f'layout model name not allow: {layout_model_name}')
exit(1)
elif model_name == AtomicModel.MFD:
atom_model = mfd_model_init(
kwargs.get('mfd_weights'),
@@ -210,6 +257,8 @@ class MineruPipelineModel:
'DocAnalysis init, this may take some times......'
)
atom_model_manager = AtomModelSingleton()
layout_model_name = get_layout_model_name()
layout_model_weight = get_layout_model_weight(layout_model_name)
if self.apply_formula:
# 初始化公式检测模型
@@ -237,11 +286,11 @@ class MineruPipelineModel:
)
# 初始化layout模型
logger.info(f'Using layout model: {layout_model_name}')
self.layout_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.Layout,
doclayout_yolo_weights=str(
os.path.join(auto_download_and_get_model_root_path(ModelPath.doclayout_yolo), ModelPath.doclayout_yolo)
),
layout_model_name=layout_model_name,
layout_model_weight=layout_model_weight,
device=self.device,
)
# 初始化ocr
@@ -369,4 +418,4 @@ class MineruHybridModel:
atom_model_name=AtomicModel.MFR,
mfr_weight_dir=str(os.path.join(auto_download_and_get_model_root_path(mfr_model_path), mfr_model_path)),
device=self.device,
)
)

File diff suppressed because it is too large Load Diff

View File

@@ -100,6 +100,7 @@ class ModelPath:
pipeline_root_modelscope = "OpenDataLab/PDF-Extract-Kit-1.0"
pipeline_root_hf = "opendatalab/PDF-Extract-Kit-1.0"
doclayout_yolo = "models/Layout/YOLO/doclayout_yolo_docstructbench_imgsz1280_2501.pt"
pp_doclayout_v2 = "models/Layout/PP-DocLayoutV2"
yolo_v8_mfd = "models/MFD/YOLO/yolo_v8_ft.pt"
unimernet_small = "models/MFR/unimernet_hf_small_2503"
pp_formulanet_plus_m = "models/MFR/pp_formulanet_plus_m"

View File

@@ -85,6 +85,7 @@ pipeline = [
"torch>=2.6.0,<3",
"torchvision",
"transformers>=4.49.0,!=4.51.0,<5.0.0",
"safetensors>=0.4.0,<1",
"onnxruntime>1.17.0",
]
api = [
@@ -169,4 +170,4 @@ exclude_also = [
'if TYPE_CHECKING:',
'class .*\bProtocol\):',
'@(abc\.)?abstractmethod',
]
]