refactor: implement dynamic batch ratio based on GPU memory and environment variable

This commit is contained in:
myhloli
2025-12-23 17:00:02 +08:00
parent 408d94ed58
commit 447ffcd32f

View File

@@ -4,6 +4,7 @@ from collections import defaultdict
import cv2
import numpy as np
from loguru import logger
from mineru_vl_utils import MinerUClient
from mineru_vl_utils.structs import BlockType
from tqdm import tqdm
@@ -12,8 +13,9 @@ from mineru.backend.hybrid.hybrid_model_output_to_middle_json import result_to_m
from mineru.backend.pipeline.model_init import HybridModelSingleton
from mineru.backend.vlm.vlm_analyze import ModelSingleton
from mineru.data.data_reader_writer import DataWriter
from mineru.utils.config_reader import get_device
from mineru.utils.enum_class import ImageType, NotExtractType
from mineru.utils.model_utils import crop_img
from mineru.utils.model_utils import crop_img, get_vram, clean_memory
from mineru.utils.ocr_utils import get_adjusted_mfdetrec_res, get_ocr_result_list, sorted_boxes, merge_det_boxes, \
update_det_boxes, OcrConfidence
from mineru.utils.pdf_classify import classify
@@ -203,7 +205,7 @@ def _process_ocr_and_formulas(
inline_formula_enable,
_ocr_enable,
_vlm_ocr_enable,
batch_radio: int = 2,
batch_radio: int = 1,
):
"""处理OCR和公式识别"""
@@ -317,6 +319,36 @@ def _normalize_bbox(
normalize_poly_to_bbox(ocr_res, page_width, page_height)
def get_batch_ratio(device):
"""
根据显存大小或环境变量获取 batch ratio
"""
# 1. 优先尝试从环境变量获取
env_val = os.getenv("MINERU_HYBRID_BATCH_RATIO")
if env_val:
try:
batch_ratio = int(env_val)
logger.info(f"hybrid batch ratio (from env): {batch_ratio}")
return batch_ratio
except ValueError as e:
logger.warning(f"Invalid MINERU_HYBRID_BATCH_RATIO value: {env_val}, switching to auto mode. Error: {e}")
# 2. 根据显存自动推断
gpu_memory = get_vram(device)
if gpu_memory >= 24:
batch_ratio = 8
elif gpu_memory >= 16:
batch_ratio = 4
elif gpu_memory >= 8:
batch_ratio = 2
else:
batch_ratio = 1
logger.info(f"hybrid batch ratio (auto, vram={gpu_memory}GB): {batch_ratio}")
return batch_ratio
def doc_analyze(
pdf_bytes,
image_writer: DataWriter | None,
@@ -335,6 +367,9 @@ def doc_analyze(
images_list, pdf_doc = load_images_from_pdf(pdf_bytes, image_type=ImageType.PIL)
images_pil_list = [image_dict["img_pil"] for image_dict in images_list]
device = get_device()
batch_ratio = get_batch_ratio(device)
_ocr_enable = ocr_classify(pdf_bytes, parse_method=parse_method)
_vlm_ocr_enable = False
if _ocr_enable and language in ["ch", "en"] and inline_formula_enable:
@@ -350,6 +385,7 @@ def doc_analyze(
inline_formula_enable,
_ocr_enable,
_vlm_ocr_enable,
batch_radio=batch_ratio,
)
_normalize_bbox(
@@ -369,6 +405,9 @@ def doc_analyze(
_vlm_ocr_enable,
hybrid_pipeline_model,
)
clean_memory(device)
return middle_json, results
@@ -390,6 +429,9 @@ async def aio_doc_analyze(
images_list, pdf_doc = load_images_from_pdf(pdf_bytes, image_type=ImageType.PIL)
images_pil_list = [image_dict["img_pil"] for image_dict in images_list]
device = get_device()
batch_ratio = get_batch_ratio(device)
_ocr_enable = ocr_classify(pdf_bytes, parse_method=parse_method)
_vlm_ocr_enable = False
if _ocr_enable and language in ["ch", "en"] and inline_formula_enable:
@@ -405,6 +447,7 @@ async def aio_doc_analyze(
inline_formula_enable,
_ocr_enable,
_vlm_ocr_enable,
batch_radio=batch_ratio,
)
_normalize_bbox(
@@ -424,5 +467,8 @@ async def aio_doc_analyze(
_vlm_ocr_enable,
hybrid_pipeline_model,
)
clean_memory(device)
return middle_json, results