mirror of
https://github.com/opendatalab/MinerU.git
synced 2026-03-27 02:58:54 +07:00
refactor: enhance batch ratio calculation based on GPU compute capability
This commit is contained in:
@@ -5,10 +5,17 @@ from collections import defaultdict
|
||||
import cv2
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
from packaging import version
|
||||
from mineru_vl_utils import MinerUClient
|
||||
from mineru_vl_utils.structs import BlockType
|
||||
from tqdm import tqdm
|
||||
|
||||
try:
|
||||
import torch
|
||||
import torch_npu
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
from mineru.backend.hybrid.hybrid_model_output_to_middle_json import result_to_middle_json
|
||||
from mineru.backend.pipeline.model_init import HybridModelSingleton
|
||||
from mineru.backend.vlm.vlm_analyze import ModelSingleton
|
||||
@@ -362,6 +369,15 @@ def get_batch_ratio(device):
|
||||
else:
|
||||
batch_ratio = 1
|
||||
|
||||
if torch.cuda.is_available():
|
||||
major, minor = torch.cuda.get_device_capability()
|
||||
# 正确计算Compute Capability
|
||||
compute_capability = f"{major}.{minor}"
|
||||
elif hasattr(torch, 'npu') and torch.npu.is_available():
|
||||
compute_capability = "8.0"
|
||||
if version.parse(compute_capability) < version.parse("8.0"):
|
||||
batch_ratio = max(1, batch_ratio // 2)
|
||||
|
||||
logger.info(f"hybrid batch ratio (auto, vram={gpu_memory}GB): {batch_ratio}")
|
||||
return batch_ratio
|
||||
|
||||
|
||||
Reference in New Issue
Block a user