mirror of
https://github.com/opendatalab/MinerU.git
synced 2026-03-27 02:58:54 +07:00
feat: add support for MUSA and NPU devices in device management functions
This commit is contained in:
@@ -20,6 +20,8 @@ def enable_custom_logits_processors() -> bool:
|
||||
compute_capability = "8.0"
|
||||
elif hasattr(torch, 'gcu') and torch.gcu.is_available():
|
||||
compute_capability = "8.0"
|
||||
elif hasattr(torch, 'musa') and torch.musa.is_available():
|
||||
compute_capability = "8.0"
|
||||
else:
|
||||
logger.info("CUDA not available, disabling custom_logits_processors")
|
||||
return False
|
||||
|
||||
@@ -189,6 +189,12 @@ def model_init(model_name: str):
|
||||
elif device_name.startswith("gcu"):
|
||||
if torch.gcu.is_bf16_supported():
|
||||
bf_16_support = True
|
||||
elif device_name.startswith("musa"):
|
||||
if torch.musa.is_bf16_supported():
|
||||
bf_16_support = True
|
||||
elif device_name.startswith("npu"):
|
||||
if torch.npu.is_bf16_supported():
|
||||
bf_16_support = True
|
||||
|
||||
if model_name == 'layoutreader':
|
||||
# 检测modelscope的缓存目录是否存在
|
||||
|
||||
@@ -90,7 +90,11 @@ def get_device():
|
||||
if torch.gcu.is_available():
|
||||
return "gcu"
|
||||
except Exception as e:
|
||||
pass
|
||||
try:
|
||||
if torch.musa.is_available():
|
||||
return "musa"
|
||||
except Exception as e:
|
||||
pass
|
||||
return "cpu"
|
||||
|
||||
|
||||
|
||||
@@ -426,6 +426,9 @@ def clean_memory(device='cuda'):
|
||||
elif str(device).startswith("gcu"):
|
||||
if torch.gcu.is_available():
|
||||
torch.gcu.empty_cache()
|
||||
elif str(device).startswith("musa"):
|
||||
if torch.musa.is_available():
|
||||
torch.musa.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
|
||||
@@ -464,5 +467,8 @@ def get_vram(device) -> int:
|
||||
elif str(device).startswith("gcu"):
|
||||
if torch.gcu.is_available():
|
||||
total_memory = round(torch.gcu.get_device_properties(device).total_memory / (1024 ** 3)) # 转为 GB
|
||||
elif str(device).startswith("musa"):
|
||||
if torch.musa.is_available():
|
||||
total_memory = round(torch.musa.get_device_properties(device).total_memory / (1024 ** 3)) # 转为 GB
|
||||
|
||||
return total_memory
|
||||
|
||||
Reference in New Issue
Block a user