@Ayush10 Ayush10 commented Jan 30, 2026

Summary

Closes #2050

Add automatic MPS (Metal Performance Shaders) device support for Apple Silicon Macs across all PyTorch model benchmarks.

Approach

Introduced a centralized get_torch_device() utility in pytorch_utils.py that selects the best available device with priority: CUDA > MPS > CPU.

import torch

def get_torch_device(GPU=0):
    """Return the best available torch.device, with priority CUDA > MPS > CPU.

    An explicit device string (e.g. GPU="cuda:1") is passed through as-is;
    a negative GPU index skips CUDA.
    """
    if isinstance(GPU, str):
        return torch.device(GPU)
    if torch.cuda.is_available() and GPU >= 0:
        return torch.device(f"cuda:{GPU}")
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

Updated all 27 model classes and 2 module-level device selections to use this function, replacing the previous inline pattern:

# Before (in every model file)
self.device = torch.device("cuda:%d" % GPU if torch.cuda.is_available() and GPU >= 0 else "cpu")

# After
self.device = get_torch_device(GPU)

Files changed

  • qlib/contrib/model/pytorch_utils.py — Added get_torch_device() alongside existing count_parameters()
  • 25 model files in qlib/contrib/model/ — Updated device selection
  • qlib/contrib/model/pytorch_tra.py — Updated module-level device
  • qlib/contrib/data/dataset.py — Updated module-level device

Backward compatibility

  • CUDA users: No change in behavior. CUDA is still preferred when available.
  • CPU users: No change. CPU is the final fallback.
  • String GPU parameter: Still supported (e.g. GPU="cuda:1").
  • The hasattr(torch.backends, "mps") guard ensures compatibility with older PyTorch versions that lack MPS support.
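The compatibility claims above can be exercised without real hardware by stubbing out torch's availability checks. This is a hypothetical test harness (not part of the PR); the fake `torch` module and its lambdas are stand-ins, and only `get_torch_device` itself is copied from the change:

```python
import sys
import types

# Hypothetical stub: a fake "torch" module whose availability checks we can
# flip, so the selection priority is testable on any machine.
fake = types.ModuleType("torch")
fake.device = lambda s: s  # record the chosen device string verbatim
fake.cuda = types.SimpleNamespace(is_available=lambda: False)
fake.backends = types.SimpleNamespace(
    mps=types.SimpleNamespace(is_available=lambda: True)
)
sys.modules["torch"] = fake
import torch

def get_torch_device(GPU=0):  # copied from the PR's pytorch_utils.py
    if isinstance(GPU, str):
        return torch.device(GPU)
    if torch.cuda.is_available() and GPU >= 0:
        return torch.device(f"cuda:{GPU}")
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

# With CUDA absent and MPS present, MPS wins.
assert get_torch_device(0) == "mps"

# With CUDA present, CUDA is still preferred -- no behavior change for CUDA users.
fake.cuda.is_available = lambda: True
assert get_torch_device(0) == "cuda:0"

# A negative GPU index skips CUDA even when it is available.
assert get_torch_device(-1) == "mps"

# An explicit string bypasses detection entirely (the pytorch_nn.py case).
assert get_torch_device("cuda:1") == "cuda:1"
print("all priority checks passed")
```

Deleting the `mps` attribute from the stub's `backends` would likewise confirm the `hasattr` guard falls through to CPU on older PyTorch builds.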

Test Plan

  • All 27 model device selections updated to use get_torch_device()
  • count_parameters() left unchanged
  • hasattr guard covers PyTorch versions without an MPS backend
  • String GPU parameter handling preserved (pytorch_nn.py special case)


Add a `get_torch_device()` utility function to `pytorch_utils.py` that
selects the best available device with priority: CUDA > MPS > CPU.

Update all 27 PyTorch model classes and 2 module-level device selections
to use this centralized function, replacing the previous inline
`torch.device("cuda:%d" % GPU if torch.cuda.is_available() ...)` pattern.

This enables automatic GPU acceleration on Apple Silicon Macs via the
MPS backend, while maintaining full backward compatibility for CUDA and
CPU users.