feat: integrate cudagraph_capture_sizes into vlm_analyze for MUSA devices

myhloli
2026-01-22 18:52:53 +08:00
parent 294105c1b0
commit 5f7214bf2f
2 changed files with 24 additions and 28 deletions


@@ -100,20 +100,4 @@ def set_default_batch_size() -> int:
     except Exception as e:
         logger.warning(f'Error determining VRAM: {e}, using default batch_ratio: 1')
         batch_size = 1
-    return batch_size
-
-
-def set_compilation_config() -> dict:
-    device = get_device()
-    compilation_config = {}
-    if str(device).startswith('musa'):
-        try:
-            import torch
-            if torch.musa.is_available():
-                compilation_config = {
-                    "cudagraph_capture_sizes": [1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 18, 20, 24, 28, 30],
-                    "simple_cuda_graph": True
-                }
-        except Exception as e:
-            pass
-    return compilation_config
+    return batch_size
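
For context, this is roughly how a compilation_config dict like the one removed above reaches vLLM. cudagraph_capture_sizes is a real field of vLLM's CompilationConfig (the batch sizes for which graphs are captured ahead of time); simple_cuda_graph is assumed here to be an extension of the MUSA vLLM build, and the model path is a placeholder:

# Hedged sketch, not code from this repository: recent vLLM accepts
# compilation_config on the LLM constructor as an int, dict, or
# CompilationConfig object.
from vllm import LLM

llm = LLM(
    model="path/to/model",  # placeholder
    compilation_config={
        "cudagraph_capture_sizes": [1, 2, 4, 8, 16],  # sizes captured as graphs
        "simple_cuda_graph": True,  # assumed MUSA-specific flag, from this commit
    },
)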


@@ -6,7 +6,7 @@ import json
 from loguru import logger
 from .utils import enable_custom_logits_processors, set_default_gpu_memory_utilization, set_default_batch_size, \
-    set_lmdeploy_backend, set_compilation_config
+    set_lmdeploy_backend
 from .model_output_to_middle_json import result_to_middle_json
 from ...data.data_reader_writer import DataWriter
 from mineru.utils.pdf_image_tools import load_images_from_pdf
@@ -100,6 +100,17 @@ class ModelSingleton:
                 import vllm
             except ImportError:
                 raise ImportError("Please install vllm to use the vllm-engine backend.")
+            device = get_device()
+            if device.startswith("musa"):
+                import torch
+                if torch.musa.is_available():
+                    compilation_config = {
+                        "cudagraph_capture_sizes": [1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 18, 20, 24, 28, 30],
+                        "simple_cuda_graph": True
+                    }
+                    block_size = 32
+                    kwargs["compilation_config"] = compilation_config
+                    kwargs["block_size"] = block_size
             if "compilation_config" in kwargs:
                 if isinstance(kwargs["compilation_config"], str):
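
The context lines above belong to an existing branch that accepts compilation_config as a JSON string (e.g. passed through from the command line) and parses it. Reconstructed roughly, under the names visible in the hunk:

# Sketch of the JSON-string branch whose edges are visible above;
# the full method body is abbreviated in the hunk.
import json
from loguru import logger

value = kwargs.get("compilation_config")
if isinstance(value, str):
    try:
        kwargs["compilation_config"] = json.loads(value)
    except json.JSONDecodeError:
        logger.warning(f"Failed to parse compilation_config as JSON: {value}")
        del kwargs["compilation_config"]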
@@ -109,11 +120,6 @@ class ModelSingleton:
                         logger.warning(
                             f"Failed to parse compilation_config as JSON: {kwargs['compilation_config']}")
                         del kwargs["compilation_config"]
-            else:
-                compilation_config = set_compilation_config()
-                if compilation_config:
-                    kwargs["compilation_config"] = compilation_config
-
             if "gpu_memory_utilization" not in kwargs:
                 kwargs["gpu_memory_utilization"] = set_default_gpu_memory_utilization()
             if "model" not in kwargs:
@@ -129,6 +135,17 @@ class ModelSingleton:
                 from vllm.v1.engine.async_llm import AsyncLLM
             except ImportError:
                 raise ImportError("Please install vllm to use the vllm-async-engine backend.")
+            device = get_device()
+            if device.startswith("musa"):
+                import torch
+                if torch.musa.is_available():
+                    compilation_config = {
+                        "cudagraph_capture_sizes": [1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 18, 20, 24, 28, 30],
+                        "simple_cuda_graph": True
+                    }
+                    block_size = 32
+                    kwargs["compilation_config"] = compilation_config
+                    kwargs["block_size"] = block_size
             if "compilation_config" in kwargs:
                 if isinstance(kwargs["compilation_config"], str):
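
The same MUSA block is duplicated for the async path. For orientation, a hedged sketch of how kwargs like these are typically handed to vLLM's v1 async engine (AsyncEngineArgs and AsyncLLM.from_engine_args are standard vLLM entry points; the actual construction call lies outside this hunk):

# Hedged sketch, not code from this commit.
from vllm.config import CompilationConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.v1.engine.async_llm import AsyncLLM

engine_args = AsyncEngineArgs(
    model="path/to/model",  # placeholder
    block_size=32,          # value this commit sets for MUSA
    compilation_config=CompilationConfig(
        cudagraph_capture_sizes=[1, 2, 4, 8],  # shortened for illustration
    ),
)
engine = AsyncLLM.from_engine_args(engine_args)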
@@ -138,11 +155,6 @@ class ModelSingleton:
                         logger.warning(
                             f"Failed to parse compilation_config as JSON: {kwargs['compilation_config']}")
                         del kwargs["compilation_config"]
-            else:
-                compilation_config = set_compilation_config()
-                if compilation_config:
-                    kwargs["compilation_config"] = compilation_config
-
             if "gpu_memory_utilization" not in kwargs:
                 kwargs["gpu_memory_utilization"] = set_default_gpu_memory_utilization()
             if "model" not in kwargs: