feat(language-detection): improve language detection accuracy for specific languages

- Add separate models for Chinese/Japanese and English/French/German detection
- Implement mode-based detection to use appropriate models for different languages
- Update language detection process to use higher DPI for better accuracy
- Modify model initialization and prediction logic to support new language-specific models
Author: myhloli
Date:   2025-01-08 19:18:33 +08:00
Parent: 3fcac5efad
Commit: 356cb1f2de

5 changed files with 69 additions and 22 deletions
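In short, detection now runs in two passes: the base classifier makes a coarse call, and the two visually confusable groups (Chinese/Japanese and English/French/German) are re-scored by a dedicated model. A minimal sketch of that routing, with identifiers taken from the diffs below (the `detect_language` wrapper and its arguments are illustrative, not part of this commit):

```python
from magic_pdf.model.sub_modules.language_detection.yolov11.YOLOv11 import (
    YOLOv11LangDetModel,
    LangDetectMode,
)

def detect_language(text_images: list, weights_dir: str, device: str = "cpu"):
    model = YOLOv11LangDetModel(weights_dir, device)
    # Pass 1: the general-purpose classifier picks a coarse language.
    lang = model.do_detect(text_images)
    # Pass 2: confusable groups are re-scored by a specialized model.
    if lang in ["ch", "japan"]:
        lang = model.do_detect(text_images, mode=LangDetectMode.CH_JP)
    elif lang in ["en", "fr", "german"]:
        lang = model.do_detect(text_images, mode=LangDetectMode.EN_FR_GE)
    return lang
```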

File 1/5 · README (known-issues section):

@@ -51,6 +51,7 @@ magic-pdf --help
 ## Known issues
-- paddleocr uses the embedded onnx model, which supports Chinese/English OCR but no other languages
+- paddleocr uses the embedded onnx model; under the default language configuration it recognizes Chinese and English fairly quickly
+- when a custom lang parameter is specified, paddleocr slows down noticeably
 - the layout model crashes intermittently when layoutlmv3 is used; the default doclayout_yolo model is recommended
 - table parsing has only been adapted for the rapid_table model; other models may not work

File 2/5 · module defining auto_detect_lang:

@@ -12,7 +12,7 @@ from magic_pdf.data.utils import load_images_from_pdf
 from magic_pdf.libs.config_reader import get_local_models_dir, get_device
 from magic_pdf.libs.pdf_check import extract_pages
 from magic_pdf.model.model_list import AtomicModel
-from magic_pdf.model.sub_modules.language_detection.yolov11.YOLOv11 import YOLOv11LangDetModel
+from magic_pdf.model.sub_modules.language_detection.yolov11.YOLOv11 import YOLOv11LangDetModel, LangDetectMode
 from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
@@ -59,15 +59,21 @@ def get_text_images(simple_images):
 def auto_detect_lang(pdf_bytes: bytes):
     sample_docs = extract_pages(pdf_bytes)
     sample_pdf_bytes = sample_docs.tobytes()
-    simple_images = load_images_from_pdf(sample_pdf_bytes, dpi=96)
+    simple_images = load_images_from_pdf(sample_pdf_bytes, dpi=200)
     text_images = get_text_images(simple_images)
     local_models_dir, device, configs = get_model_config()
     # use yolo11 for language classification
-    langdetect_model_weights = str(
+    langdetect_model_weights_dir = str(
         os.path.join(
             local_models_dir, configs['weights'][MODEL_NAME.YOLO_V11_LangDetect]
         )
     )
-    langdetect_model = YOLOv11LangDetModel(langdetect_model_weights, device)
+    langdetect_model = YOLOv11LangDetModel(langdetect_model_weights_dir, device)
     lang = langdetect_model.do_detect(text_images)
+    if lang in ["ch", "japan"]:
+        lang = langdetect_model.do_detect(text_images, mode=LangDetectMode.CH_JP)
+    elif lang in ["en", "fr", "german"]:
+        lang = langdetect_model.do_detect(text_images, mode=LangDetectMode.EN_FR_GE)
     return lang
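For callers, detection stays a single call; an illustrative use of the function above (the containing module's path is not shown in this diff, so no import line is given):

```python
# Illustrative only: auto_detect_lang is the function defined above.
with open("sample.pdf", "rb") as f:
    lang = auto_detect_lang(f.read())
# lang is a code such as "ch", "japan", "en", "fr" or "german"
```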

File 3/5 · magic_pdf/model/sub_modules/language_detection/yolov11/YOLOv11.py:

@@ -1,7 +1,9 @@
 # Copyright (c) Opendatalab. All rights reserved.
 import os
 from collections import Counter
 from uuid import uuid4
+import torch
 from PIL import Image
 from loguru import logger
 from ultralytics import YOLO
@@ -17,6 +19,11 @@ language_dict = {
     "ru": "俄语"
 }
+
+class LangDetectMode:
+    BASE = "base"
+    CH_JP = "ch_jp"
+    EN_FR_GE = "en_fr_ge"

 def split_images(image, result_images=None):
     """
@@ -83,11 +90,25 @@ def resize_images_to_224(image):
 class YOLOv11LangDetModel(object):
-    def __init__(self, weight, device):
-        self.model = YOLO(weight)
-        self.device = device
+    def __init__(self, langdetect_model_weights_dir, device):
+        langdetect_model_base_weight = str(
+            os.path.join(langdetect_model_weights_dir, 'yolo_v11_cls_ft.pt')
+        )
+        langdetect_model_ch_jp_weight = str(
+            os.path.join(langdetect_model_weights_dir, 'yolo_v11_cls_ch_jp.pt')
+        )
+        langdetect_model_en_fr_ge_weight = str(
+            os.path.join(langdetect_model_weights_dir, 'yolo_v11_cls_en_fr_ge.pt')
+        )
+        self.model = YOLO(langdetect_model_base_weight)
+        self.ch_jp_model = YOLO(langdetect_model_ch_jp_weight)
+        self.en_fr_ge_model = YOLO(langdetect_model_en_fr_ge_weight)

-    def do_detect(self, images: list):
+        if str(device).startswith("npu"):
+            self.device = torch.device(device)
+        else:
+            self.device = device
+
+    def do_detect(self, images: list, mode=LangDetectMode.BASE):
         all_images = []
         for image in images:
             width, height = image.size
@@ -98,7 +119,7 @@ class YOLOv11LangDetModel(object):
         for temp_image in temp_images:
             all_images.append(resize_images_to_224(temp_image))

-        images_lang_res = self.batch_predict(all_images, batch_size=8)
+        images_lang_res = self.batch_predict(all_images, batch_size=8, mode=mode)
         logger.info(f"images_lang_res: {images_lang_res}")
         if len(images_lang_res) > 0:
             count_dict = Counter(images_lang_res)
@@ -107,20 +128,39 @@ class YOLOv11LangDetModel(object):
             language = None
         return language

-    def predict(self, image):
-        results = self.model.predict(image, verbose=False, device=self.device)
+    def predict(self, image, mode=LangDetectMode.BASE):
+        if mode == LangDetectMode.BASE:
+            model = self.model
+        elif mode == LangDetectMode.CH_JP:
+            model = self.ch_jp_model
+        elif mode == LangDetectMode.EN_FR_GE:
+            model = self.en_fr_ge_model
+        else:
+            model = self.model
+        results = model.predict(image, verbose=False, device=self.device)
         predicted_class_id = int(results[0].probs.top1)
-        predicted_class_name = self.model.names[predicted_class_id]
+        predicted_class_name = model.names[predicted_class_id]
         return predicted_class_name

-    def batch_predict(self, images: list, batch_size: int) -> list:
+    def batch_predict(self, images: list, batch_size: int, mode=LangDetectMode.BASE) -> list:
         images_lang_res = []
+        if mode == LangDetectMode.BASE:
+            model = self.model
+        elif mode == LangDetectMode.CH_JP:
+            model = self.ch_jp_model
+        elif mode == LangDetectMode.EN_FR_GE:
+            model = self.en_fr_ge_model
+        else:
+            model = self.model
         for index in range(0, len(images), batch_size):
             lang_res = [
                 image_res.cpu()
-                for image_res in self.model.predict(
+                for image_res in model.predict(
                     images[index: index + batch_size],
                     verbose = False,
                     device=self.device,
@@ -128,7 +168,7 @@ class YOLOv11LangDetModel(object):
             ]
             for res in lang_res:
                 predicted_class_id = int(res.probs.top1)
-                predicted_class_name = self.model.names[predicted_class_id]
+                predicted_class_name = model.names[predicted_class_id]
                 images_lang_res.append(predicted_class_name)
         return images_lang_res
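As the first hunk of this file shows, `do_detect` aggregates the per-crop predictions with a `Counter`; the selection line itself falls outside the hunk, so the sketch below assumes the usual most-common pick:

```python
from collections import Counter
from typing import Optional

def majority_language(images_lang_res: list) -> Optional[str]:
    # Majority vote over per-crop predictions; on a tie, most_common
    # returns the label Counter encountered first.
    if not images_lang_res:
        return None
    return Counter(images_lang_res).most_common(1)[0][0]

assert majority_language(["en", "fr", "en"]) == "en"
```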

File 4/5 · ModifiedPaddleOCR wrapper:

@@ -21,7 +21,7 @@ class ModifiedPaddleOCR(PaddleOCR):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
+        self.lang = kwargs.get('lang', 'ch')
         # fall back to onnx when the CPU architecture is arm and CUDA is unavailable
         if not torch.cuda.is_available() and platform.machine() in ['arm64', 'aarch64']:
             self.use_onnx = True
@@ -94,7 +94,7 @@ class ModifiedPaddleOCR(PaddleOCR):
         ocr_res = []
         for img in imgs:
             img = preprocess_image(img)
-            if self.use_onnx:
+            if self.lang in ['ch'] and self.use_onnx:
                 dt_boxes, elapse = self.additional_ocr.text_detector(img)
             else:
                 dt_boxes, elapse = self.text_detector(img)
@@ -124,7 +124,7 @@ class ModifiedPaddleOCR(PaddleOCR):
                 img, cls_res_tmp, elapse = self.text_classifier(img)
                 if not rec:
                     cls_res.append(cls_res_tmp)
-            if self.use_onnx:
+            if self.lang in ['ch'] and self.use_onnx:
                 rec_res, elapse = self.additional_ocr.text_recognizer(img)
             else:
                 rec_res, elapse = self.text_recognizer(img)
@@ -142,7 +142,7 @@ class ModifiedPaddleOCR(PaddleOCR):
         start = time.time()
         ori_im = img.copy()
-        if self.use_onnx:
+        if self.lang in ['ch'] and self.use_onnx:
             dt_boxes, elapse = self.additional_ocr.text_detector(img)
         else:
             dt_boxes, elapse = self.text_detector(img)
@@ -183,7 +183,7 @@ class ModifiedPaddleOCR(PaddleOCR):
             time_dict['cls'] = elapse
             logger.debug("cls num : {}, elapsed : {}".format(
                 len(img_crop_list), elapse))
-        if self.use_onnx:
+        if self.lang in ['ch'] and self.use_onnx:
             rec_res, elapse = self.additional_ocr.text_recognizer(img_crop_list)
         else:
             rec_res, elapse = self.text_recognizer(img_crop_list)
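Net effect of the four guards above: the embedded onnx detector and recognizer now serve only the default 'ch' setting (and, per the constructor, only on arm machines without CUDA); any other `lang` falls through to the standard Paddle pipeline, which is what the README's new note about slower custom-lang runs refers to. A brief illustrative construction:

```python
# Illustrative; extra kwargs are forwarded to PaddleOCR's constructor.
ocr_default = ModifiedPaddleOCR()               # lang defaults to 'ch': eligible for the onnx fast path
ocr_japanese = ModifiedPaddleOCR(lang='japan')  # non-'ch': always uses the standard Paddle models
```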

File 5/5 · model weights configuration (YAML):

@@ -6,4 +6,4 @@ weights:
   struct_eqtable: TabRec/StructEqTable
   tablemaster: TabRec/TableMaster
   rapid_table: TabRec/RapidTable
-  yolo_v11n_langdetect: LangDetect/YOLO/yolo_v11_cls_ft.pt
+  yolo_v11n_langdetect: LangDetect/YOLO
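Since the config entry now names a directory instead of a single checkpoint, that directory is expected to contain the three weight files hard-coded in `YOLOv11LangDetModel.__init__` (role descriptions inferred from the filenames and the commit message):

```
LangDetect/YOLO/
├── yolo_v11_cls_ft.pt         # base classifier (all supported languages)
├── yolo_v11_cls_ch_jp.pt      # Chinese vs. Japanese
└── yolo_v11_cls_en_fr_ge.pt   # English vs. French vs. German
```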