mirror of
https://github.com/opendatalab/MinerU.git
synced 2026-03-27 11:08:32 +07:00
feat: add compilation configuration support for MUSA devices in utils and vlm_analyze
This commit is contained in:
@@ -100,4 +100,19 @@ def set_default_batch_size() -> int:
|
||||
except Exception as e:
|
||||
logger.warning(f'Error determining VRAM: {e}, using default batch_ratio: 1')
|
||||
batch_size = 1
|
||||
return batch_size
|
||||
return batch_size
|
||||
|
||||
|
||||
def set_compilation_config() -> dict:
|
||||
device = get_device()
|
||||
compilation_config = {}
|
||||
if str(device).startswith('musa'):
|
||||
try:
|
||||
import torch
|
||||
if torch.musa.is_available():
|
||||
compilation_config = {
|
||||
"simple_cuda_graph": true
|
||||
}
|
||||
except Exception as e:
|
||||
pass
|
||||
return compilation_config
|
||||
@@ -99,6 +99,20 @@ class ModelSingleton:
|
||||
import vllm
|
||||
except ImportError:
|
||||
raise ImportError("Please install vllm to use the vllm-engine backend.")
|
||||
|
||||
if "compilation_config" in kwargs:
|
||||
if isinstance(kwargs["compilation_config"], str):
|
||||
try:
|
||||
kwargs["compilation_config"] = json.loads(kwargs["compilation_config"])
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
f"Failed to parse compilation_config as JSON: {kwargs['compilation_config']}")
|
||||
del kwargs["compilation_config"]
|
||||
else:
|
||||
compilation_config = set_compilation_config()
|
||||
if compilation_config:
|
||||
kwargs["compilation_config"] = compilation_config
|
||||
|
||||
if "gpu_memory_utilization" not in kwargs:
|
||||
kwargs["gpu_memory_utilization"] = set_default_gpu_memory_utilization()
|
||||
if "model" not in kwargs:
|
||||
@@ -114,6 +128,20 @@ class ModelSingleton:
|
||||
from vllm.v1.engine.async_llm import AsyncLLM
|
||||
except ImportError:
|
||||
raise ImportError("Please install vllm to use the vllm-async-engine backend.")
|
||||
|
||||
if "compilation_config" in kwargs:
|
||||
if isinstance(kwargs["compilation_config"], str):
|
||||
try:
|
||||
kwargs["compilation_config"] = json.loads(kwargs["compilation_config"])
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
f"Failed to parse compilation_config as JSON: {kwargs['compilation_config']}")
|
||||
del kwargs["compilation_config"]
|
||||
else:
|
||||
compilation_config = set_compilation_config()
|
||||
if compilation_config:
|
||||
kwargs["compilation_config"] = compilation_config
|
||||
|
||||
if "gpu_memory_utilization" not in kwargs:
|
||||
kwargs["gpu_memory_utilization"] = set_default_gpu_memory_utilization()
|
||||
if "model" not in kwargs:
|
||||
|
||||
Reference in New Issue
Block a user