feat: add cudagraph_capture_sizes to compilation configuration for MUSA devices

This commit is contained in:
myhloli
2026-01-22 18:40:22 +08:00
parent e8548eddde
commit 294105c1b0

View File

@@ -111,6 +111,7 @@ def set_compilation_config() -> dict:
import torch
if torch.musa.is_available():
compilation_config = {
"cudagraph_capture_sizes": [1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 18, 20, 24, 28, 30],
"simple_cuda_graph": True
}
except Exception as e: