PyTorch provides GPU acceleration on Apple Silicon and Intel-based Macs through the Metal Performance Shaders (MPS) backend, enabling significantly faster computation on Apple GPUs.
Overview
Metal is Apple’s API for programming the GPU on macOS. The MPS backend in PyTorch leverages Metal Performance Shaders to accelerate deep learning workloads on Apple hardware, including M1, M2, M3, and later Apple Silicon chips.
MPS support is available on macOS 12.3 and later, and requires either an Apple Silicon chip (M1 or newer) or, on Intel-based Macs, a supported AMD GPU.
Installation
MPS support is included in the standard PyTorch installation for macOS:
```shell
# Install PyTorch with MPS support (conda)
conda install pytorch torchvision torchaudio -c pytorch

# Or using pip
pip install torch torchvision torchaudio
```
No additional installation steps are required. MPS support is automatically included in macOS builds of PyTorch.
Device Management
Checking MPS Availability
```python
import torch

# Check if MPS is available
if torch.backends.mps.is_available():
    print("MPS device is available")
    print(f"Number of MPS devices: {torch.mps.device_count()}")
else:
    print("MPS device is not available")
    if not torch.backends.mps.is_built():
        print("PyTorch was not built with MPS support")
```
Using MPS Device
```python
import torch

# Create tensors on the MPS device
mps_device = torch.device("mps")

# Method 1: Specify the device during creation
tensor = torch.randn(1000, 1000, device=mps_device)

# Method 2: Move an existing tensor to MPS
tensor_cpu = torch.randn(1000, 1000)
tensor_mps = tensor_cpu.to('mps')

# Method 3: Use to() with a device object
tensor_mps = tensor_cpu.to(mps_device)
```
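Since not every machine has an MPS device, scripts are often written with a CPU fallback so the same code runs everywhere. A minimal sketch of that pattern (not shown in the snippets above):

```python
import torch

# Pick MPS when available, otherwise fall back to CPU.
# This keeps the same script portable across Macs and other machines.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

x = torch.randn(8, 8, device=device)
print(device.type)
```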
Tensor Operations
Creating and Moving Tensors
```python
import torch

# Create tensors on MPS
x = torch.randn(100, 100, device='mps')
y = torch.randn(100, 100, device='mps')

# Perform operations on MPS
z = x @ y  # Matrix multiplication on the GPU
result = z.sum()

# Move the result back to the CPU
result_cpu = result.to('cpu')
```
Model Training on MPS
```python
import torch
import torch.nn as nn

# Define a model
model = nn.Sequential(
    nn.Linear(784, 128),
    nn.ReLU(),
    nn.Linear(128, 10)
)

# Move the model to the MPS device
mps_device = torch.device("mps")
model = model.to(mps_device)

# Training loop (assumes train_loader is a DataLoader defined elsewhere)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
num_epochs = 10

for epoch in range(num_epochs):
    for inputs, targets in train_loader:
        # Move data to MPS
        inputs = inputs.to(mps_device)
        targets = targets.to(mps_device)

        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```
For optimal performance, keep data and model on the MPS device throughout training to avoid expensive CPU-GPU transfers.
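One subtlety worth knowing: calls like `loss.item()` also force a device-to-host copy and a synchronization, so it helps to accumulate the loss on-device and read it back only occasionally. A sketch with a toy model and synthetic data (both assumed for illustration):

```python
import torch
import torch.nn as nn

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

model = nn.Linear(16, 4).to(device)       # move parameters once
opt = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()

running = 0.0
for step in range(20):
    x = torch.randn(32, 16, device=device)       # toy data, created on-device
    y = torch.randint(0, 4, (32,), device=device)
    opt.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    opt.step()
    running += loss.detach()                     # stays on device; no sync
    if (step + 1) % 10 == 0:
        # .item() forces a device-to-host copy; do it only when logging
        print(f"step {step + 1}: avg loss {(running / 10).item():.4f}")
        running = 0.0
```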
Memory Management
```python
import torch

# Get the currently allocated memory
allocated = torch.mps.current_allocated_memory()
print(f"Currently allocated: {allocated / 1e9:.2f} GB")

# Get the total driver-allocated memory (includes cached memory)
driver_allocated = torch.mps.driver_allocated_memory()
print(f"Driver allocated: {driver_allocated / 1e9:.2f} GB")

# Get the recommended maximum working set size
recommended_max = torch.mps.recommended_max_memory()
print(f"Recommended max: {recommended_max / 1e9:.2f} GB")
```
Memory Cleanup
```python
# Free unused cached memory
torch.mps.empty_cache()

# Set a memory fraction limit (valid range 0-2):
# 0 = unlimited, 1 = recommended max, >1 = beyond the recommended max
torch.mps.set_per_process_memory_fraction(0.8)
```
Setting the memory fraction to 0 allows unlimited allocation, which can destabilize the system if memory is exhausted. Values between 0.5 and 1.0 are safer.
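One way to enforce that advice in code is a small guard. `set_safe_memory_fraction` below is a hypothetical helper of my own, not a torch API; it clamps the requested value into the conservative band before applying it:

```python
import torch

def set_safe_memory_fraction(fraction: float) -> float:
    """Clamp the requested fraction into the 0.5-1.0 band before
    applying it (illustrative helper, not part of torch)."""
    clamped = min(max(fraction, 0.5), 1.0)
    if torch.backends.mps.is_available():
        torch.mps.set_per_process_memory_fraction(clamped)
    return clamped

print(set_safe_memory_fraction(1.5))  # clamped down to 1.0
```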
Memory Limits
```python
# Set a conservative memory limit (80% of recommended)
torch.mps.set_per_process_memory_fraction(0.8)

# Check memory before a large allocation
if torch.mps.current_allocated_memory() > torch.mps.recommended_max_memory() * 0.9:
    print("Warning: Approaching memory limit")
    torch.mps.empty_cache()
```
Synchronization
Device Synchronization
```python
import torch

# Operations on MPS execute asynchronously
tensor = torch.randn(10000, 10000, device='mps')
result = tensor @ tensor.T

# Wait for all queued operations to complete
torch.mps.synchronize()

# Now it is safe to access the result on the CPU
result_cpu = result.to('cpu')
print(result_cpu)
```
Unlike CUDA, MPS doesn’t have user-controllable streams. All operations are submitted to the Metal command queue and executed asynchronously.
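Because of this, host-side reads and timing code need an explicit synchronize. A device-agnostic helper (my own sketch, not a torch API) keeps the same code path working on CPU, CUDA, and MPS:

```python
import torch

def sync(device: torch.device) -> None:
    """Block until queued work on the given device has finished.
    No-op on CPU, where operations run synchronously."""
    if device.type == "mps":
        torch.mps.synchronize()
    elif device.type == "cuda":
        torch.cuda.synchronize(device)

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
x = torch.randn(256, 256, device=device)
y = x @ x.T
sync(device)  # safe point to read results or stop a timer
print(y.shape)
```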
Random Number Generation
```python
import torch

# Set the random seed for reproducibility
torch.mps.manual_seed(42)

# Save the RNG state
rng_state = torch.mps.get_rng_state()

# Perform random operations
random_tensor = torch.randn(100, 100, device='mps')

# Restore the saved RNG state
torch.mps.set_rng_state(rng_state)

# Reseed from a nondeterministic source
torch.mps.seed()
```
Events for Synchronization
```python
import torch
from torch.mps.event import Event

# Create events
start_event = Event()
end_event = Event()

# Record the start event
start_event.record()

# Perform operations (assumes model and input_tensor are defined)
output = model(input_tensor.to('mps'))

# Record the end event
end_event.record()

# Wait for all queued operations to complete
torch.mps.synchronize()

# query() reports whether the work recorded before the event has finished
if end_event.query():
    print("Operations completed")
```
Custom Metal Shaders
The MPS backend can compile and run custom Metal compute shaders via torch.mps.compile_shader (available in newer PyTorch releases).
```python
import torch

# Define the Metal shader source: fill the output buffer with a constant
shader_source = """
kernel void full(
    device float* out,
    constant float& val,
    uint idx [[thread_position_in_grid]]
) {
    out[idx] = val;
}
"""

# Compile the shader library
lib = torch.mps.compile_shader(shader_source)

# Create an output tensor
x = torch.zeros(16, device="mps")

# Execute the custom kernel
lib.full(x, 3.14)
torch.mps.synchronize()
print(x)  # tensor([3.14, 3.14, ..., 3.14])
```
Custom Metal shaders can be used to implement operations not available in PyTorch or optimize specific kernels for Apple GPUs.
Best Practices
```python
import torch
import torch.nn as nn

# 1. Keep the model and data on the MPS device
#    (assumes model, criterion, optimizer, and dataset are defined)
mps_device = torch.device("mps")
model = model.to(mps_device)

# 2. Use larger batch sizes when possible
train_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=128,    # Larger batches better utilize the GPU
    num_workers=2,
    pin_memory=False   # Pinned memory is a CUDA feature; not needed for MPS
)

# 3. Avoid frequent CPU-GPU transfers
for inputs, targets in train_loader:
    # Move once at the start of each batch
    inputs = inputs.to(mps_device)
    targets = targets.to(mps_device)

    # All subsequent operations stay on the GPU
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()
```
Benchmarking
```python
import torch
import time

def benchmark_operation(device, size=1000, iterations=100):
    x = torch.randn(size, size, device=device)

    # Warmup
    for _ in range(10):
        _ = x @ x.T
    if device == 'mps':
        torch.mps.synchronize()

    # Benchmark
    start = time.perf_counter()
    for _ in range(iterations):
        result = x @ x.T
    if device == 'mps':
        torch.mps.synchronize()
    end = time.perf_counter()

    elapsed = (end - start) / iterations * 1000
    print(f"{device.upper()}: {elapsed:.2f} ms per iteration")

# Compare CPU vs MPS
benchmark_operation('cpu')
benchmark_operation('mps')
```
Profiling
```python
import torch
from torch.profiler import profile, ProfilerActivity

# Profile a forward pass (assumes model and input_tensor are defined)
with profile(
    activities=[ProfilerActivity.CPU],
    record_shapes=True,
    with_stack=True
) as prof:
    model(input_tensor.to('mps'))
    torch.mps.synchronize()

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
```
MPS profiling currently focuses on CPU-side overhead. For detailed GPU profiling, use Xcode’s Metal System Trace instrument.
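To make PyTorch work visible inside those Instruments traces, torch.mps also ships a small profiler module that emits OS signpost intervals. A sketch, guarded so it is a no-op on machines without MPS (my understanding of the torch.mps.profiler API; verify against your PyTorch version):

```python
import torch

# torch.mps.profiler.start()/stop() bracket a region with OS signposts
# that show up in Instruments (e.g. alongside a Metal System Trace).
ran = False
if torch.backends.mps.is_available():
    torch.mps.profiler.start(mode="interval", wait_until_completed=False)
    x = torch.randn(512, 512, device="mps")
    _ = (x @ x.T).sum()
    torch.mps.synchronize()
    torch.mps.profiler.stop()
    ran = True
print(ran)
```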
Supported Operations
Most PyTorch operations are supported on MPS, including:
- Tensor creation and manipulation
- Mathematical operations (add, mul, matmul, etc.)
- Neural network layers (Linear, Conv2d, RNN, etc.)
- Optimizers (Adam, SGD, etc.)
- Loss functions
- Autograd and backpropagation
Checking Operation Support
```python
import torch

try:
    # Attempt the operation on MPS
    x = torch.randn(10, 10, device='mps')
    result = torch.fft.fft(x)
    print("FFT is supported on MPS")
except RuntimeError as e:
    print(f"Operation not supported: {e}")
    # Fall back to CPU
    result = torch.fft.fft(x.cpu()).to('mps')
```
Common Issues
NotImplementedError: The operator 'X' is not currently implemented for the MPS device.
Some operations are not yet implemented for MPS. Fall back to CPU for the unsupported operation:
```python
try:
    result = unsupported_op(tensor_mps)
except NotImplementedError:
    result = unsupported_op(tensor_mps.cpu()).to('mps')
```
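Alternatively, PyTorch supports a global CPU fallback for unimplemented MPS operators via the PYTORCH_ENABLE_MPS_FALLBACK environment variable. It is read when torch is imported, so it must be set before the Python process starts:

```shell
# PYTORCH_ENABLE_MPS_FALLBACK must be set before the interpreter launches,
# e.g. inline on the command line (shown here just echoing the flag back;
# in practice you would run your training script this way)
PYTORCH_ENABLE_MPS_FALLBACK=1 python3 -c 'import os; print(os.environ["PYTORCH_ENABLE_MPS_FALLBACK"])'
```

Note that fallback ops run on the CPU and can be much slower, so this is a convenience rather than a performance fix.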
Out of Memory Errors
If you encounter OOM errors on MPS:
- Reduce the batch size
- Set a memory fraction: `torch.mps.set_per_process_memory_fraction(0.7)`
- Clear the cache: `torch.mps.empty_cache()`
- Close other GPU-intensive applications
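If a smaller batch size hurts convergence, gradient accumulation keeps the effective batch size while lowering peak memory. A toy sketch (model and data invented for illustration):

```python
import torch
import torch.nn as nn

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

model = nn.Linear(32, 2).to(device)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()

accum_steps = 4   # effective batch = micro_batch * accum_steps
micro_batch = 8   # only this much activation memory is live at once

opt.zero_grad()
for step in range(8):
    x = torch.randn(micro_batch, 32, device=device)
    y = torch.randint(0, 2, (micro_batch,), device=device)
    loss = criterion(model(x), y) / accum_steps  # scale so gradients average
    loss.backward()                              # gradients accumulate in .grad
    if (step + 1) % accum_steps == 0:
        opt.step()
        opt.zero_grad()
```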
Fork Safety
```python
import torch
import multiprocessing as mp

# MPS may have issues with fork-based multiprocessing;
# use the spawn start method instead
if __name__ == '__main__':
    mp.set_start_method('spawn')
    # Your multiprocessing code here
```
System Requirements
Minimum Requirements
- macOS 12.3 or later
- Python 3.8 or later
- PyTorch 1.12 or later
- Apple Silicon (M1/M2/M3) or AMD GPU
Checking System Compatibility
```python
import torch
import platform

print(f"macOS version: {platform.mac_ver()[0]}")
print(f"PyTorch version: {torch.__version__}")
print(f"MPS built: {torch.backends.mps.is_built()}")
print(f"MPS available: {torch.backends.mps.is_available()}")
```
Comparison: MPS vs CPU
```python
import torch
import torch.nn as nn
import time

def compare_performance(model, input_size, device):
    model = model.to(device)
    input_tensor = torch.randn(input_size, device=device)

    # Warmup
    for _ in range(10):
        _ = model(input_tensor)
    if device == 'mps':
        torch.mps.synchronize()

    # Benchmark
    start = time.perf_counter()
    for _ in range(100):
        output = model(input_tensor)
    if device == 'mps':
        torch.mps.synchronize()

    return time.perf_counter() - start

model = nn.Sequential(
    nn.Linear(1000, 500),
    nn.ReLU(),
    nn.Linear(500, 10)
)

cpu_time = compare_performance(model, (32, 1000), 'cpu')
mps_time = compare_performance(model, (32, 1000), 'mps')

print(f"CPU time: {cpu_time:.2f}s")
print(f"MPS time: {mps_time:.2f}s")
print(f"Speedup: {cpu_time/mps_time:.2f}x")
```
API Reference
Key MPS functions and classes:
- `torch.backends.mps.is_available()` - Check MPS availability
- `torch.backends.mps.is_built()` - Check whether PyTorch was built with MPS support
- `torch.mps.device_count()` - Get the number of MPS devices (0 or 1)
- `torch.mps.synchronize()` - Wait for all queued operations to complete
- `torch.mps.empty_cache()` - Release unused cached memory
- `torch.mps.current_allocated_memory()` - Get current memory usage
- `torch.mps.driver_allocated_memory()` - Get total driver-allocated memory
- `torch.mps.set_per_process_memory_fraction()` - Set the memory limit
- `torch.mps.manual_seed()` - Set the random seed
- `torch.mps.event.Event` - Synchronization primitive
- `torch.mps.compile_shader()` - Compile custom Metal shaders
For more information, see the MPS backend documentation.