PyTorch provides GPU acceleration on Apple Silicon and Intel-based Macs through the Metal Performance Shaders (MPS) backend, enabling significantly faster training and inference than the CPU.

Overview

Metal is Apple’s API for programming the GPU on macOS. The MPS backend in PyTorch leverages Metal Performance Shaders to accelerate deep learning workloads on Apple hardware, including M1, M2, M3, and later Apple Silicon chips.
MPS support is available on macOS 12.3+ and requires an Apple Silicon (M1/M2/M3) or AMD GPU on Intel-based Macs.

Installation

MPS support is included in the standard PyTorch installation for macOS:
# Install PyTorch with MPS support (conda)
conda install pytorch torchvision torchaudio -c pytorch

# Or using pip
pip install torch torchvision torchaudio
No additional installation steps are required. MPS support is automatically included in macOS builds of PyTorch.

Device Management

Checking MPS Availability

import torch

# Check if MPS is available
if torch.backends.mps.is_available():
    print("MPS device is available")
    print(f"Number of MPS devices: {torch.mps.device_count()}")
else:
    print("MPS device is not available")
    if not torch.backends.mps.is_built():
        print("PyTorch was not built with MPS support")

Using MPS Device

import torch

# Create tensors on MPS device
mps_device = torch.device("mps")

# Method 1: Specify device during creation
tensor = torch.randn(1000, 1000, device=mps_device)

# Method 2: Move existing tensor to MPS
tensor_cpu = torch.randn(1000, 1000)
tensor_mps = tensor_cpu.to('mps')

# Method 3: Using to() with device object
tensor_mps = tensor_cpu.to(mps_device)

Tensor Operations

Creating and Moving Tensors

import torch

# Create tensor on MPS
x = torch.randn(100, 100, device='mps')
y = torch.randn(100, 100, device='mps')

# Perform operations on MPS
z = x @ y  # Matrix multiplication on GPU
result = z.sum()

# Move back to CPU
result_cpu = result.to('cpu')

Model Training on MPS

import torch
import torch.nn as nn

# Define model
model = nn.Sequential(
    nn.Linear(784, 128),
    nn.ReLU(),
    nn.Linear(128, 10)
)

# Move model to MPS device
mps_device = torch.device("mps")
model = model.to(mps_device)

# Training loop (assumes num_epochs and train_loader are defined)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

for epoch in range(num_epochs):
    for inputs, targets in train_loader:
        # Move data to MPS
        inputs = inputs.to(mps_device)
        targets = targets.to(mps_device)
        
        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
For optimal performance, keep data and model on the MPS device throughout training to avoid expensive CPU-GPU transfers.
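A common way to follow this advice is to select the device once at startup and fall back to the CPU when MPS is unavailable, so the same script runs everywhere. The snippet below is a minimal sketch of that pattern (the layer sizes are arbitrary):

```python
import torch
import torch.nn as nn

# Pick the best available device once, falling back to CPU
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

model = nn.Linear(8, 4).to(device)
batch = torch.randn(2, 8, device=device)

# The forward pass runs on whichever device was selected
output = model(batch)
print(output.shape, output.device)
```

Everything created with `device=device` then stays on the GPU (or CPU) for the lifetime of the program, avoiding per-step transfers.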

Memory Management

Memory Information

import torch

# Get current allocated memory
allocated = torch.mps.current_allocated_memory()
print(f"Currently allocated: {allocated / 1e9:.2f} GB")

# Get total driver allocated memory (includes cached)
driver_allocated = torch.mps.driver_allocated_memory()
print(f"Driver allocated: {driver_allocated / 1e9:.2f} GB")

# Get recommended max working set size
recommended_max = torch.mps.recommended_max_memory()
print(f"Recommended max: {recommended_max / 1e9:.2f} GB")

Memory Cleanup

# Free unused cached memory
torch.mps.empty_cache()

# Set memory fraction limit (0-2 range)
# 0 = unlimited, 1 = recommended max, >1 = beyond recommended
torch.mps.set_per_process_memory_fraction(0.8)
Setting the memory fraction to 0 allows unlimited allocations, which may cause system instability if memory is exhausted. Values between 0.5 and 1.0 are a safer choice.

Memory Limits

# Set conservative memory limit (80% of recommended)
torch.mps.set_per_process_memory_fraction(0.8)

# Check memory before large allocation
if torch.mps.current_allocated_memory() > torch.mps.recommended_max_memory() * 0.9:
    print("Warning: Approaching memory limit")
    torch.mps.empty_cache()

Synchronization

Device Synchronization

import torch

# Asynchronous operations on MPS
tensor = torch.randn(10000, 10000, device='mps')
result = tensor @ tensor.T

# Wait for all operations to complete
torch.mps.synchronize()

# Now safe to access result on CPU
result_cpu = result.to('cpu')
print(result_cpu)
Unlike CUDA, MPS doesn’t have user-controllable streams. All operations are submitted to the Metal command queue and executed asynchronously.
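Because kernel launches return before the GPU finishes, timing code must synchronize explicitly before reading the clock. The sketch below illustrates this (it falls back to the CPU, where execution is synchronous, when MPS is unavailable; the measured numbers depend on hardware):

```python
import time
import torch

device = "mps" if torch.backends.mps.is_available() else "cpu"
x = torch.randn(2048, 2048, device=device)

# The launch returns quickly; the GPU may still be working
t0 = time.perf_counter()
y = x @ x
launch = time.perf_counter() - t0

# synchronize() blocks until all queued work has finished
if device == "mps":
    torch.mps.synchronize()
total = time.perf_counter() - t0

# Copying to the CPU (or calling .item()) also waits implicitly
checksum = y.sum().item()
print(f"launch: {launch * 1e3:.2f} ms, after sync: {total * 1e3:.2f} ms")
```

This is also why reading a result with `.item()` or `.to('cpu')` appears slow in profiles: the cost of the preceding kernels is attributed to the first operation that has to wait for them.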

Random Number Generation

import torch

# Set random seed for reproducibility
torch.mps.manual_seed(42)

# Get RNG state
rng_state = torch.mps.get_rng_state()

# Perform random operations
random_tensor = torch.randn(100, 100, device='mps')

# Restore RNG state
torch.mps.set_rng_state(rng_state)

# Reseed the generator with a nondeterministic random seed
torch.mps.seed()

Events for Synchronization

import torch
from torch.mps import Event

# Create events
start_event = Event()
end_event = Event()

# Record start event
start_event.record()

# Perform operations
output = model(input_tensor.to('mps'))

# Record end event
end_event.record()

# Wait for operations to complete
torch.mps.synchronize()

# Events can be used for synchronization
if end_event.query():
    print("Operations completed")

Custom Metal Shaders

PyTorch MPS allows you to compile and run custom Metal compute shaders.
import torch

# Define Metal shader source
shader_source = """
kernel void full(
    device float* out,
    constant float& val,
    uint idx [[thread_position_in_grid]]
) {
    out[idx] = val;
}
"""

# Compile shader library
lib = torch.mps.compile_shader(shader_source)

# Create output tensor
x = torch.zeros(16, device="mps")

# Execute custom kernel
lib.full(x, 3.14)

torch.mps.synchronize()
print(x)  # tensor([3.14, 3.14, ..., 3.14])
Custom Metal shaders can be used to implement operations not available in PyTorch or optimize specific kernels for Apple GPUs.

Performance Optimization

Best Practices

import torch
import torch.nn as nn

# 1. Keep data on MPS device
mps_device = torch.device("mps")
model = model.to(mps_device)

# 2. Use larger batch sizes when possible
train_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=128,  # Larger batches better utilize GPU
    num_workers=2,
    pin_memory=False  # Not needed for MPS
)

# 3. Avoid frequent CPU-GPU transfers
for inputs, targets in train_loader:
    # Move once at batch start
    inputs = inputs.to(mps_device)
    targets = targets.to(mps_device)
    
    # All operations stay on GPU
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

Benchmarking

import torch
import time

def benchmark_operation(device, size=1000, iterations=100):
    x = torch.randn(size, size, device=device)
    
    # Warmup
    for _ in range(10):
        _ = x @ x.T
    
    if device == 'mps':
        torch.mps.synchronize()
    
    # Benchmark
    start = time.time()
    for _ in range(iterations):
        result = x @ x.T
    
    if device == 'mps':
        torch.mps.synchronize()
    
    end = time.time()
    elapsed = (end - start) / iterations * 1000
    print(f"{device.upper()}: {elapsed:.2f} ms per iteration")

# Compare CPU vs MPS
benchmark_operation('cpu')
benchmark_operation('mps')

Profiling

from torch.profiler import profile, ProfilerActivity

with profile(
    activities=[ProfilerActivity.CPU],
    record_shapes=True,
    with_stack=True
) as prof:
    model(input_tensor.to('mps'))
    torch.mps.synchronize()

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
MPS profiling currently focuses on CPU-side overhead. For detailed GPU profiling, use Xcode’s Metal System Trace instrument.

Supported Operations

Most PyTorch operations are supported on MPS, including:
  • Tensor creation and manipulation
  • Mathematical operations (add, mul, matmul, etc.)
  • Neural network layers (Linear, Conv2d, RNN, etc.)
  • Optimizers (Adam, SGD, etc.)
  • Loss functions
  • Autograd and backpropagation
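As an illustrative smoke test, the snippet below touches one item from each category above (tensor math, a layer, a loss, an optimizer, and a backward pass); it falls back to the CPU when MPS is unavailable, and the layer sizes are arbitrary:

```python
import torch
import torch.nn as nn

device = "mps" if torch.backends.mps.is_available() else "cpu"

# Tensor creation and mathematical operations
a = torch.randn(4, 4, device=device)
b = (a + 1.0) @ a.T

# A small layer, loss function, optimizer, and one autograd step
layer = nn.Linear(4, 2).to(device)
opt = torch.optim.SGD(layer.parameters(), lr=0.1)
target = torch.zeros(4, 2, device=device)

loss = nn.functional.mse_loss(layer(b), target)
opt.zero_grad()
loss.backward()
opt.step()

print(f"loss on {device}: {loss.item():.4f}")
```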

Checking Operation Support

import torch

try:
    # Attempt operation on MPS
    x = torch.randn(10, 10, device='mps')
    result = torch.fft.fft(x)
    print("FFT is supported on MPS")
except RuntimeError as e:
    print(f"Operation not supported: {e}")
    # Fallback to CPU
    result = torch.fft.fft(x.cpu()).to('mps')

Common Issues

NotImplementedError

The operator 'X' is not currently implemented for the MPS device.

Some operations are not yet implemented for MPS. Fall back to the CPU:
try:
    result = unsupported_op(tensor_mps)
except NotImplementedError:
    result = unsupported_op(tensor_mps.cpu()).to('mps')
Alternatively, set the environment variable PYTORCH_ENABLE_MPS_FALLBACK=1 before launching Python to have PyTorch fall back to the CPU automatically for unsupported operators.
Out of Memory Errors

If you encounter out-of-memory (OOM) errors on MPS:
  1. Reduce batch size
  2. Set memory fraction: torch.mps.set_per_process_memory_fraction(0.7)
  3. Clear cache: torch.mps.empty_cache()
  4. Close other GPU-intensive applications
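Steps 1 and 3 can be combined into a retry loop that halves the batch and clears the cache whenever an allocation fails. This is a hedged sketch, not a PyTorch API: run_with_backoff and step_fn are hypothetical names, and it matches OOM failures by inspecting the RuntimeError message, which is a heuristic:

```python
import torch

def run_with_backoff(step_fn, batch, min_batch=1):
    """Retry step_fn with a smaller batch when the device runs out of memory.

    step_fn is a user-supplied callable (hypothetical); batch is a tensor
    whose first dimension is the batch dimension.
    """
    size = batch.shape[0]
    while size >= min_batch:
        try:
            return step_fn(batch[:size])
        except RuntimeError as e:
            # Heuristic: only treat allocation failures as retryable
            if "out of memory" not in str(e).lower():
                raise
            # Free cached blocks, then retry with half the batch
            if torch.backends.mps.is_available():
                torch.mps.empty_cache()
            size //= 2
    raise RuntimeError("Batch too small to fit in memory")

# Usage sketch: result = run_with_backoff(lambda b: model(b), inputs)
```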

Fork Safety

import torch
import multiprocessing as mp

# MPS may have issues with fork
# Use spawn method for multiprocessing
if __name__ == '__main__':
    mp.set_start_method('spawn')
    # Your multiprocessing code here

System Requirements

Minimum Requirements

  • macOS 12.3 or later
  • Python 3.8 or later
  • PyTorch 1.12 or later
  • Apple Silicon (M1/M2/M3) or AMD GPU

Checking System Compatibility

import torch
import platform

print(f"macOS version: {platform.mac_ver()[0]}")
print(f"PyTorch version: {torch.__version__}")
print(f"MPS built: {torch.backends.mps.is_built()}")
print(f"MPS available: {torch.backends.mps.is_available()}")

Comparison: MPS vs CPU

import torch
import torch.nn as nn
import time

def compare_performance(model, input_size, device):
    model = model.to(device)
    input_tensor = torch.randn(input_size, device=device)
    
    # Warmup
    for _ in range(10):
        _ = model(input_tensor)
    
    if device == 'mps':
        torch.mps.synchronize()
    
    # Benchmark
    start = time.time()
    for _ in range(100):
        output = model(input_tensor)
    
    if device == 'mps':
        torch.mps.synchronize()
    
    elapsed = time.time() - start
    return elapsed

model = nn.Sequential(
    nn.Linear(1000, 500),
    nn.ReLU(),
    nn.Linear(500, 10)
)

cpu_time = compare_performance(model, (32, 1000), 'cpu')
mps_time = compare_performance(model, (32, 1000), 'mps')

print(f"CPU time: {cpu_time:.2f}s")
print(f"MPS time: {mps_time:.2f}s") 
print(f"Speedup: {cpu_time/mps_time:.2f}x")

API Reference

Key MPS functions and classes:
  • torch.backends.mps.is_available() - Check MPS availability
  • torch.backends.mps.is_built() - Check if PyTorch was built with MPS
  • torch.mps.device_count() - Get number of MPS devices (0 or 1)
  • torch.mps.synchronize() - Wait for all operations to complete
  • torch.mps.empty_cache() - Release unused cached memory
  • torch.mps.current_allocated_memory() - Get current memory usage
  • torch.mps.driver_allocated_memory() - Get total driver allocated memory
  • torch.mps.set_per_process_memory_fraction() - Set memory limit
  • torch.mps.manual_seed() - Set random seed
  • torch.mps.Event - Synchronization primitive
  • torch.mps.compile_shader() - Compile custom Metal shaders
For more information, see the MPS backend documentation.