mirror of
https://github.com/opendatalab/MinerU.git
synced 2026-03-27 11:08:32 +07:00
refactor(langdetect): simplify language detection model and improve logging
- Remove LangDetectMode and related conditional logic
- Use a single model weight for language detection
- Add logging for language detection results
- Update model initialization and prediction methods
This commit is contained in:
@@ -153,6 +153,7 @@ class PymuDocDataset(Dataset):
|
||||
logger.info(f"lang: {lang}, detect_lang: {self._lang}")
|
||||
else:
|
||||
self._lang = lang
|
||||
logger.info(f"lang: {lang}")
|
||||
def __len__(self) -> int:
|
||||
"""The page number of the pdf."""
|
||||
return len(self._records)
|
||||
|
||||
@@ -12,7 +12,6 @@ from magic_pdf.data.utils import load_images_from_pdf
|
||||
from magic_pdf.libs.config_reader import get_local_models_dir, get_device
|
||||
from magic_pdf.libs.pdf_check import extract_pages
|
||||
from magic_pdf.model.model_list import AtomicModel
|
||||
from magic_pdf.model.sub_modules.language_detection.yolov11.YOLOv11 import LangDetectMode
|
||||
from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
|
||||
|
||||
|
||||
@@ -63,11 +62,6 @@ def auto_detect_lang(pdf_bytes: bytes):
|
||||
text_images = get_text_images(simple_images)
|
||||
langdetect_model = model_init(MODEL_NAME.YOLO_V11_LangDetect)
|
||||
lang = langdetect_model.do_detect(text_images)
|
||||
if lang in ["ch", "japan"]:
|
||||
lang = langdetect_model.do_detect(text_images, mode=LangDetectMode.CH_JP)
|
||||
elif lang in ["en", "fr", "german"]:
|
||||
lang = langdetect_model.do_detect(text_images, mode=LangDetectMode.EN_FR_GE)
|
||||
|
||||
return lang
|
||||
|
||||
|
||||
@@ -79,7 +73,7 @@ def model_init(model_name: str):
|
||||
model = atom_model_manager.get_atom_model(
|
||||
atom_model_name=AtomicModel.LangDetect,
|
||||
langdetect_model_name=MODEL_NAME.YOLO_V11_LangDetect,
|
||||
langdetect_model_weights_dir=str(
|
||||
langdetect_model_weight=str(
|
||||
os.path.join(
|
||||
local_models_dir, configs['weights'][MODEL_NAME.YOLO_V11_LangDetect]
|
||||
)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
# Copyright (c) Opendatalab. All rights reserved.
|
||||
import os
|
||||
from collections import Counter
|
||||
from uuid import uuid4
|
||||
|
||||
@@ -19,11 +18,6 @@ language_dict = {
|
||||
"ru": "俄语"
|
||||
}
|
||||
|
||||
class LangDetectMode:
|
||||
BASE = "base"
|
||||
CH_JP = "ch_jp"
|
||||
EN_FR_GE = "en_fr_ge"
|
||||
|
||||
|
||||
def split_images(image, result_images=None):
|
||||
"""
|
||||
@@ -90,25 +84,15 @@ def resize_images_to_224(image):
|
||||
|
||||
|
||||
class YOLOv11LangDetModel(object):
|
||||
def __init__(self, langdetect_model_weights_dir, device):
|
||||
langdetect_model_base_weight = str(
|
||||
os.path.join(langdetect_model_weights_dir, 'yolo_v11_cls_ft.pt')
|
||||
)
|
||||
langdetect_model_ch_jp_weight = str(
|
||||
os.path.join(langdetect_model_weights_dir, 'yolo_v11_cls_ch_jp.pt')
|
||||
)
|
||||
langdetect_model_en_fr_ge_weight = str(
|
||||
os.path.join(langdetect_model_weights_dir, 'yolo_v11_cls_en_fr_ge.pt')
|
||||
)
|
||||
self.model = YOLO(langdetect_model_base_weight)
|
||||
self.ch_jp_model = YOLO(langdetect_model_ch_jp_weight)
|
||||
self.en_fr_ge_model = YOLO(langdetect_model_en_fr_ge_weight)
|
||||
def __init__(self, langdetect_model_weight, device):
|
||||
|
||||
self.model = YOLO(langdetect_model_weight)
|
||||
|
||||
if str(device).startswith("npu"):
|
||||
self.device = torch.device(device)
|
||||
else:
|
||||
self.device = device
|
||||
def do_detect(self, images: list, mode=LangDetectMode.BASE):
|
||||
def do_detect(self, images: list):
|
||||
all_images = []
|
||||
for image in images:
|
||||
width, height = image.size
|
||||
@@ -119,8 +103,8 @@ class YOLOv11LangDetModel(object):
|
||||
for temp_image in temp_images:
|
||||
all_images.append(resize_images_to_224(temp_image))
|
||||
|
||||
images_lang_res = self.batch_predict(all_images, batch_size=8, mode=mode)
|
||||
logger.info(f"images_lang_res: {images_lang_res}")
|
||||
images_lang_res = self.batch_predict(all_images, batch_size=8)
|
||||
# logger.info(f"images_lang_res: {images_lang_res}")
|
||||
if len(images_lang_res) > 0:
|
||||
count_dict = Counter(images_lang_res)
|
||||
language = max(count_dict, key=count_dict.get)
|
||||
@@ -128,39 +112,20 @@ class YOLOv11LangDetModel(object):
|
||||
language = None
|
||||
return language
|
||||
|
||||
def predict(self, image, mode=LangDetectMode.BASE):
|
||||
|
||||
if mode == LangDetectMode.BASE:
|
||||
model = self.model
|
||||
elif mode == LangDetectMode.CH_JP:
|
||||
model = self.ch_jp_model
|
||||
elif mode == LangDetectMode.EN_FR_GE:
|
||||
model = self.en_fr_ge_model
|
||||
else:
|
||||
model = self.model
|
||||
|
||||
results = model.predict(image, verbose=False, device=self.device)
|
||||
def predict(self, image):
|
||||
results = self.model.predict(image, verbose=False, device=self.device)
|
||||
predicted_class_id = int(results[0].probs.top1)
|
||||
predicted_class_name = model.names[predicted_class_id]
|
||||
predicted_class_name = self.model.names[predicted_class_id]
|
||||
return predicted_class_name
|
||||
|
||||
|
||||
def batch_predict(self, images: list, batch_size: int, mode=LangDetectMode.BASE) -> list:
|
||||
def batch_predict(self, images: list, batch_size: int) -> list:
|
||||
images_lang_res = []
|
||||
|
||||
if mode == LangDetectMode.BASE:
|
||||
model = self.model
|
||||
elif mode == LangDetectMode.CH_JP:
|
||||
model = self.ch_jp_model
|
||||
elif mode == LangDetectMode.EN_FR_GE:
|
||||
model = self.en_fr_ge_model
|
||||
else:
|
||||
model = self.model
|
||||
|
||||
for index in range(0, len(images), batch_size):
|
||||
lang_res = [
|
||||
image_res.cpu()
|
||||
for image_res in model.predict(
|
||||
for image_res in self.model.predict(
|
||||
images[index: index + batch_size],
|
||||
verbose = False,
|
||||
device=self.device,
|
||||
@@ -168,7 +133,7 @@ class YOLOv11LangDetModel(object):
|
||||
]
|
||||
for res in lang_res:
|
||||
predicted_class_id = int(res.probs.top1)
|
||||
predicted_class_name = model.names[predicted_class_id]
|
||||
predicted_class_name = self.model.names[predicted_class_id]
|
||||
images_lang_res.append(predicted_class_name)
|
||||
|
||||
return images_lang_res
|
||||
@@ -63,10 +63,10 @@ def doclayout_yolo_model_init(weight, device='cpu'):
|
||||
return model
|
||||
|
||||
|
||||
def langdetect_model_init(langdetect_model_weights_dir, device='cpu'):
|
||||
def langdetect_model_init(langdetect_model_weight, device='cpu'):
|
||||
if str(device).startswith("npu"):
|
||||
device = torch.device(device)
|
||||
model = YOLOv11LangDetModel(langdetect_model_weights_dir, device)
|
||||
model = YOLOv11LangDetModel(langdetect_model_weight, device)
|
||||
return model
|
||||
|
||||
|
||||
@@ -168,7 +168,7 @@ def atom_model_init(model_name: str, **kwargs):
|
||||
elif model_name == AtomicModel.LangDetect:
|
||||
if kwargs.get('langdetect_model_name') == MODEL_NAME.YOLO_V11_LangDetect:
|
||||
atom_model = langdetect_model_init(
|
||||
kwargs.get('langdetect_model_weights_dir'),
|
||||
kwargs.get('langdetect_model_weight'),
|
||||
kwargs.get('device')
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -6,4 +6,4 @@ weights:
|
||||
struct_eqtable: TabRec/StructEqTable
|
||||
tablemaster: TabRec/TableMaster
|
||||
rapid_table: TabRec/RapidTable
|
||||
yolo_v11n_langdetect: LangDetect/YOLO
|
||||
yolo_v11n_langdetect: LangDetect/YOLO/yolo_v11_ft.pt
|
||||
Reference in New Issue
Block a user