Introducing PyTorch-DirectML: Train your machine learning models on any GPU

Adele Parsons

The Windows AI team is excited to announce the first preview of DirectML as a backend to PyTorch for training ML models! This release is our first step towards unlocking accelerated machine learning training for PyTorch on any DirectX12 GPU on Windows and the Windows Subsystem for Linux (WSL).

To let you take advantage of DirectML within PyTorch, today we are releasing a preview PyTorch-DirectML package, which provides scoped support for convolutional neural networks (CNNs). The package integrates DirectML with the PyTorch framework by introducing a new device named “DML,” which calls into the DirectML APIs and PyTorch Tensor primitives. Calling into the DirectML operators adds minimal overhead, and the DirectML backend works the same way as other existing PyTorch backends. We co-engineered with AMD, Intel, and NVIDIA to enable this hardware-accelerated training experience for PyTorch.

[Image: PyTorch DirectML Arc]

The PyTorch-DirectML package is easy to install, and using it requires changing only one line of code in an existing script. To get started, install the package by running:

pip install pytorch-directml

or download the package from PyPI. To use the DirectML backend, the only code change necessary is to specify it by calling .to("dml"). See the code below for an example:

a = torch.tensor([[1, 2, 3], [1, 2, 3]]).to("dml")

That’s it! You’re all set to start running your PyTorch training scripts!
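
If the same script also has to run on machines where the DirectML backend is not available, one option is to probe for the device before using it. The following is only a minimal sketch and not part of the PyTorch-DirectML package: pick_device is a hypothetical helper name, and the sketch assumes that a PyTorch build without DirectML raises RuntimeError when asked to move a tensor to "dml".

```python
def pick_device(preferred="dml", fallback="cpu"):
    """Return `preferred` if a tensor can be moved there, else `fallback`.

    Hypothetical helper for illustration; not part of PyTorch-DirectML.
    """
    # Imported lazily so the helper can be defined even where PyTorch is
    # not installed; calling it there raises ImportError.
    import torch

    try:
        # Probe with a tiny tensor. Assumption: builds that lack the
        # backend raise RuntimeError for an unrecognized device string.
        torch.zeros(1).to(preferred)
    except RuntimeError:
        return fallback
    return preferred
```

A script could call device = pick_device() once at startup and then pass device to .to(...) in place of the hard-coded "dml" string, so the rest of the training code stays unchanged.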


Try it out and stay involved

We encourage you to try PyTorch-DirectML and give us feedback on how the DirectML backend is working for you!

To help you get started, we published a tutorial on GitHub for training SqueezeNet and ResNet.

Please leave questions, suggestions, or issues on GitHub. Our team actively monitors the community’s feedback, and we look forward to hearing from you!