Transformer support for PyTorch with DirectML is here!

Adele Parsons

The latest release of PyTorch with DirectML is available today! This release adds support for training popular Transformer models such as GPT-2, BERT, and Detection Transformers. To get started training Transformer models with PyTorch with DirectML, see the new sample on the DirectML GitHub repository. The sample trains a PyTorch implementation of the Transformer model from the popular paper “Attention Is All You Need” (Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, Illia Polosukhin, arXiv, 2017).
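The sketch below is not the DirectML sample. It is a minimal, illustrative training loop that assumes PyTorch's built-in torch.nn.Transformer, a made-up toy task (predict the next token of a random sequence), and tiny hypothetical sizes (vocabulary 16, d_model 32). Positional encodings are omitted for brevity, so this is not a faithful reproduction of the paper's model. The helper pick_device is our own: it uses torch_directml.device() when the torch-directml plugin is installed and falls back to the CPU otherwise, so the code also runs on machines without DirectML.

```python
import torch
import torch.nn as nn


def pick_device():
    """Hypothetical helper: DirectML device if torch-directml is installed, else CPU."""
    try:
        import torch_directml
        return torch_directml.device()
    except ImportError:
        return torch.device("cpu")


def transformer_train_steps(steps=3, seed=0):
    torch.manual_seed(seed)
    device = pick_device()
    vocab, d_model = 16, 32  # hypothetical toy sizes chosen for illustration
    embed = nn.Embedding(vocab, d_model).to(device)
    model = nn.Transformer(
        d_model=d_model, nhead=4,
        num_encoder_layers=2, num_decoder_layers=2,
        dim_feedforward=64, dropout=0.0, batch_first=True,
    ).to(device)
    head = nn.Linear(d_model, vocab).to(device)  # maps decoder output to vocabulary logits
    params = [*embed.parameters(), *model.parameters(), *head.parameters()]
    opt = torch.optim.Adam(params, lr=1e-3)

    # Toy task: given the full source sequence, predict each next source token.
    src = torch.randint(0, vocab, (8, 10)).to(device)  # shape (batch, seq_len)
    tgt_in, tgt_out = src[:, :-1], src[:, 1:]
    # Causal mask so decoder position i attends only to positions <= i.
    tgt_mask = model.generate_square_subsequent_mask(tgt_in.size(1)).to(device)

    losses = []
    for _ in range(steps):
        opt.zero_grad()
        out = model(embed(src), embed(tgt_in), tgt_mask=tgt_mask)
        logits = head(out)  # shape (batch, seq_len - 1, vocab)
        loss = nn.functional.cross_entropy(logits.reshape(-1, vocab), tgt_out.reshape(-1))
        loss.backward()
        opt.step()
        losses.append(loss.item())  # .item() copies the scalar back to the host
    return losses
```

Calling transformer_train_steps() returns one loss value per optimizer step; a real training run would add positional encodings, a real dataset, and many more steps.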

This release of PyTorch with DirectML also reduces memory consumption, which unlocks faster performance and makes it possible to train with larger batch sizes.

Finally, PyTorch with DirectML now follows a plugin model and supports the latest version of PyTorch (1.13). After installing PyTorch, run pip install torch-directml to get started. Once the Torch-DirectML plugin is installed, you can begin training AI models, starting with the following lines:

```python
import torch
import torch_directml

dml = torch_directml.device()
tensor = torch.tensor([1]).to(dml)  # Note that dml is a variable, not a string!
```

Please note that this release of the Torch-DirectML plugin is mapped to the “PrivateUse1” Torch backend. The new torch_directml.device() API is a convenient wrapper for sending your tensors to the DirectML device. Now you’re ready to train your models using PyTorch with DirectML!
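The object returned by torch_directml.device() is passed to .to() just like any other PyTorch device, so an ordinary training loop only needs its model and data moved there. The sketch below assumes a toy linear-regression problem we made up for illustration; the names pick_device and fit_linear are hypothetical, and the fallback to the CPU when the plugin is missing is our own addition rather than part of the plugin. It shows two points worth keeping in mind: the model's parameters and its inputs must live on the same device, and .item() or .cpu() copies results back to the host.

```python
import torch
import torch.nn as nn


def pick_device():
    """Hypothetical helper: DirectML device if torch-directml is installed, else CPU."""
    try:
        import torch_directml
        return torch_directml.device()
    except ImportError:
        return torch.device("cpu")


def fit_linear(steps=5, seed=0):
    torch.manual_seed(seed)
    dml = pick_device()
    # Model parameters and input tensors must be on the same device.
    model = nn.Linear(4, 1).to(dml)
    x = torch.randn(32, 4).to(dml)
    y = x.sum(dim=1, keepdim=True)  # computed on dml, because x already lives there
    opt = torch.optim.SGD(model.parameters(), lr=0.1)

    losses = []
    for _ in range(steps):
        opt.zero_grad()
        loss = nn.functional.mse_loss(model(x), y)
        loss.backward()
        opt.step()
        losses.append(loss.item())  # .item() brings the scalar back to the host
    # .cpu() copies the learned weights off the device for inspection.
    return losses, model.weight.detach().cpu()
```

Because the targets are the sum of the four inputs, the learned weights should drift toward 1.0 as steps increases; the loss list lets you confirm training is making progress on whichever device was selected.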

Please leave any questions, suggestions, or issues on the DirectML GitHub repository. Our team actively engages with the community and would love to hear your input!