mirror of
https://github.com/opendatalab/MinerU.git
synced 2026-03-27 02:58:54 +07:00
feat: integrate cudagraph_capture_sizes into vlm_analyze for MUSA devices
This commit is contained in:
@@ -100,20 +100,4 @@ def set_default_batch_size() -> int:
|
||||
except Exception as e:
|
||||
logger.warning(f'Error determining VRAM: {e}, using default batch_ratio: 1')
|
||||
batch_size = 1
|
||||
return batch_size
|
||||
|
||||
|
||||
def set_compilation_config() -> dict:
|
||||
device = get_device()
|
||||
compilation_config = {}
|
||||
if str(device).startswith('musa'):
|
||||
try:
|
||||
import torch
|
||||
if torch.musa.is_available():
|
||||
compilation_config = {
|
||||
"cudagraph_capture_sizes": [1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 18, 20, 24, 28, 30],
|
||||
"simple_cuda_graph": True
|
||||
}
|
||||
except Exception as e:
|
||||
pass
|
||||
return compilation_config
|
||||
return batch_size
|
||||
@@ -6,7 +6,7 @@ import json
|
||||
from loguru import logger
|
||||
|
||||
from .utils import enable_custom_logits_processors, set_default_gpu_memory_utilization, set_default_batch_size, \
|
||||
set_lmdeploy_backend, set_compilation_config
|
||||
set_lmdeploy_backend
|
||||
from .model_output_to_middle_json import result_to_middle_json
|
||||
from ...data.data_reader_writer import DataWriter
|
||||
from mineru.utils.pdf_image_tools import load_images_from_pdf
|
||||
@@ -100,6 +100,17 @@ class ModelSingleton:
|
||||
import vllm
|
||||
except ImportError:
|
||||
raise ImportError("Please install vllm to use the vllm-engine backend.")
|
||||
device = get_device()
|
||||
if device.startswith("musa"):
|
||||
import torch
|
||||
if torch.musa.is_available():
|
||||
compilation_config = {
|
||||
"cudagraph_capture_sizes": [1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 18, 20, 24, 28, 30],
|
||||
"simple_cuda_graph": True
|
||||
}
|
||||
block_size = 32
|
||||
kwargs["compilation_config"] = compilation_config
|
||||
kwargs["block_size"] = block_size
|
||||
|
||||
if "compilation_config" in kwargs:
|
||||
if isinstance(kwargs["compilation_config"], str):
|
||||
@@ -109,11 +120,6 @@ class ModelSingleton:
|
||||
logger.warning(
|
||||
f"Failed to parse compilation_config as JSON: {kwargs['compilation_config']}")
|
||||
del kwargs["compilation_config"]
|
||||
else:
|
||||
compilation_config = set_compilation_config()
|
||||
if compilation_config:
|
||||
kwargs["compilation_config"] = compilation_config
|
||||
|
||||
if "gpu_memory_utilization" not in kwargs:
|
||||
kwargs["gpu_memory_utilization"] = set_default_gpu_memory_utilization()
|
||||
if "model" not in kwargs:
|
||||
@@ -129,6 +135,17 @@ class ModelSingleton:
|
||||
from vllm.v1.engine.async_llm import AsyncLLM
|
||||
except ImportError:
|
||||
raise ImportError("Please install vllm to use the vllm-async-engine backend.")
|
||||
device = get_device()
|
||||
if device.startswith("musa"):
|
||||
import torch
|
||||
if torch.musa.is_available():
|
||||
compilation_config = {
|
||||
"cudagraph_capture_sizes": [1,2,3,4,5,6,7,8,10,12,14,16,18,20,24,28,30],
|
||||
"simple_cuda_graph": True
|
||||
}
|
||||
block_size = 32
|
||||
kwargs["compilation_config"] = compilation_config
|
||||
kwargs["block_size"] = block_size
|
||||
|
||||
if "compilation_config" in kwargs:
|
||||
if isinstance(kwargs["compilation_config"], str):
|
||||
@@ -138,11 +155,6 @@ class ModelSingleton:
|
||||
logger.warning(
|
||||
f"Failed to parse compilation_config as JSON: {kwargs['compilation_config']}")
|
||||
del kwargs["compilation_config"]
|
||||
else:
|
||||
compilation_config = set_compilation_config()
|
||||
if compilation_config:
|
||||
kwargs["compilation_config"] = compilation_config
|
||||
|
||||
if "gpu_memory_utilization" not in kwargs:
|
||||
kwargs["gpu_memory_utilization"] = set_default_gpu_memory_utilization()
|
||||
if "model" not in kwargs:
|
||||
|
||||
Reference in New Issue
Block a user