feat: add macOS MPS device support to all PyTorch models #2110
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Summary
Closes #2050
Add automatic MPS (Metal Performance Shaders) device support for Apple Silicon Macs across all PyTorch model benchmarks.
Approach
Introduced a centralized
get_torch_device()utility inpytorch_utils.pythat selects the best available device with priority: CUDA > MPS > CPU.Updated all 27 model classes and 2 module-level device selections to use this function, replacing the previous inline pattern:
Files changed
qlib/contrib/model/pytorch_utils.py— Addedget_torch_device()alongside existingcount_parameters()qlib/contrib/model/— Updated device selectionqlib/contrib/model/pytorch_tra.py— Updated module-level deviceqlib/contrib/data/dataset.py— Updated module-level deviceBackward compatibility
GPU="cuda:1").hasattr(torch.backends, "mps")guard ensures compatibility with older PyTorch versions that lack MPS support.Test Plan
count_parameters()function preserved unchangedhasattrguard for PyTorch versions without MPS backend