feat: add compilation configuration support for MUSA devices in utils and vlm_analyze

This commit is contained in:
myhloli
2026-01-22 18:21:04 +08:00
parent ffecb89e33
commit b9465238f5
2 changed files with 44 additions and 1 deletions

View File

@@ -100,4 +100,19 @@ def set_default_batch_size() -> int:
except Exception as e:
logger.warning(f'Error determining VRAM: {e}, using default batch_ratio: 1')
batch_size = 1
return batch_size
return batch_size
def set_compilation_config() -> dict:
    """Return the device-specific default ``compilation_config``.

    Only MUSA devices get a non-empty default: when ``get_device()`` reports
    a device whose string form starts with ``'musa'`` and
    ``torch.musa.is_available()`` is true, ``simple_cuda_graph`` is enabled.
    In every other case an empty dict is returned, meaning there are no
    device-specific settings to apply.

    This function is best-effort and never raises for a failed MUSA probe:
    any error while importing torch or querying the MUSA runtime is logged
    and an empty dict is returned.

    Returns:
        dict: ``{"simple_cuda_graph": True}`` on an available MUSA device,
        otherwise ``{}``. A fresh dict is returned on every call.
    """
    device = get_device()
    if not str(device).startswith('musa'):
        return {}

    try:
        # Imported lazily so that non-MUSA code paths never pay for (or
        # depend on) the torch import here.
        import torch

        if torch.musa.is_available():
            # Must be the Python literal ``True``; the JSON-style ``true`` is
            # an undefined name and would raise NameError at runtime.
            return {"simple_cuda_graph": True}
    except Exception as e:
        # Keep the original best-effort contract, but surface the failure
        # instead of silently discarding it.
        logger.warning(
            f'Error checking MUSA availability: {e}, using empty compilation_config'
        )
    return {}

View File

@@ -99,6 +99,20 @@ class ModelSingleton:
import vllm
except ImportError:
raise ImportError("Please install vllm to use the vllm-engine backend.")
if "compilation_config" in kwargs:
if isinstance(kwargs["compilation_config"], str):
try:
kwargs["compilation_config"] = json.loads(kwargs["compilation_config"])
except json.JSONDecodeError:
logger.warning(
f"Failed to parse compilation_config as JSON: {kwargs['compilation_config']}")
del kwargs["compilation_config"]
else:
compilation_config = set_compilation_config()
if compilation_config:
kwargs["compilation_config"] = compilation_config
if "gpu_memory_utilization" not in kwargs:
kwargs["gpu_memory_utilization"] = set_default_gpu_memory_utilization()
if "model" not in kwargs:
@@ -114,6 +128,20 @@ class ModelSingleton:
from vllm.v1.engine.async_llm import AsyncLLM
except ImportError:
raise ImportError("Please install vllm to use the vllm-async-engine backend.")
if "compilation_config" in kwargs:
if isinstance(kwargs["compilation_config"], str):
try:
kwargs["compilation_config"] = json.loads(kwargs["compilation_config"])
except json.JSONDecodeError:
logger.warning(
f"Failed to parse compilation_config as JSON: {kwargs['compilation_config']}")
del kwargs["compilation_config"]
else:
compilation_config = set_compilation_config()
if compilation_config:
kwargs["compilation_config"] = compilation_config
if "gpu_memory_utilization" not in kwargs:
kwargs["gpu_memory_utilization"] = set_default_gpu_memory_utilization()
if "model" not in kwargs: