feat: add support for MUSA and NPU devices in device management functions

This commit is contained in:
myhloli
2026-01-22 15:45:26 +08:00
parent 313ec8afa0
commit 6a75b39940
4 changed files with 19 additions and 1 deletion

View File

@@ -20,6 +20,8 @@ def enable_custom_logits_processors() -> bool:
compute_capability = "8.0"
elif hasattr(torch, 'gcu') and torch.gcu.is_available():
compute_capability = "8.0"
elif hasattr(torch, 'musa') and torch.musa.is_available():
compute_capability = "8.0"
else:
logger.info("CUDA not available, disabling custom_logits_processors")
return False

View File

@@ -189,6 +189,12 @@ def model_init(model_name: str):
elif device_name.startswith("gcu"):
if torch.gcu.is_bf16_supported():
bf_16_support = True
elif device_name.startswith("musa"):
if torch.musa.is_bf16_supported():
bf_16_support = True
elif device_name.startswith("npu"):
if torch.npu.is_bf16_supported():
bf_16_support = True
if model_name == 'layoutreader':
# 检测modelscope的缓存目录是否存在

View File

@@ -90,7 +90,11 @@ def get_device():
if torch.gcu.is_available():
return "gcu"
except Exception as e:
pass
try:
if torch.musa.is_available():
return "musa"
except Exception as e:
pass
return "cpu"

View File

@@ -426,6 +426,9 @@ def clean_memory(device='cuda'):
elif str(device).startswith("gcu"):
if torch.gcu.is_available():
torch.gcu.empty_cache()
elif str(device).startswith("musa"):
if torch.musa.is_available():
torch.musa.empty_cache()
gc.collect()
@@ -464,5 +467,8 @@ def get_vram(device) -> int:
elif str(device).startswith("gcu"):
if torch.gcu.is_available():
total_memory = round(torch.gcu.get_device_properties(device).total_memory / (1024 ** 3)) # 转为 GB
elif str(device).startswith("musa"):
if torch.musa.is_available():
total_memory = round(torch.musa.get_device_properties(device).total_memory / (1024 ** 3)) # 转为 GB
return total_memory