Introducing PyTorch-DirectML: Train your machine learning models on any GPU

Adele Parsons

The Windows AI team is excited to announce the first preview of DirectML as a backend to PyTorch for training ML models! This release is our first step towards unlocking accelerated machine learning training for PyTorch on any DirectX 12 GPU on Windows and the Windows Subsystem for Linux (WSL).

To help you take advantage of DirectML within PyTorch, today we are releasing a preview PyTorch-DirectML package, which provides scoped support for convolutional neural networks (CNNs). In this package, DirectML is integrated with the PyTorch framework by introducing a new device named “DML,” which calls on the DirectML APIs and PyTorch Tensor primitives. Calling into the DirectML operators adds minimal overhead, and the DirectML backend works in the same way as other existing PyTorch backends. We co-engineered with AMD, Intel, and NVIDIA to enable this hardware-accelerated training experience for PyTorch.
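Because the backend is exposed as an ordinary PyTorch device, model code does not need to change; only the device string does. The sketch below illustrates this under the assumption that the preview build accepts the device string "dml" (as in the snippet later in this post). The helper name conv_forward and the layer sizes are hypothetical, chosen only for illustration: the same small convolution runs on whatever device you pass, and the result is copied back to the CPU so it can be compared across backends.

```python
import torch
import torch.nn as nn


def conv_forward(device: str, seed: int = 0) -> torch.Tensor:
    """Run one small convolution on `device` and return the output on the CPU.

    Hypothetical helper: the identical code path is used for "cpu", "cuda",
    or (assumed, on the preview build) "dml"; only the device string differs.
    """
    torch.manual_seed(seed)  # same weights and input regardless of device
    conv = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3, padding=1)
    x = torch.randn(2, 3, 16, 16)  # batch of 2 RGB 16x16 inputs, created on the CPU
    conv = conv.to(device)         # move the layer's parameters to the target device
    y = conv(x.to(device))         # the forward pass executes on that device
    return y.cpu()                 # copy back to the host for comparison
```

On the preview package you could compare conv_forward("dml") with conv_forward("cpu") using torch.allclose and a small tolerance, since results from different backends may differ slightly in the last few bits.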


The PyTorch-DirectML package is easy to install, and only requires changing one line of code in an existing script. To get started, you can install the package by calling:

pip install pytorch-directml

or download the package from PyPI. To use the DirectML backend, the only code change necessary is to specify it as the device, for example by calling .to("dml") on a tensor. See the code below for an example:

import torch
a = torch.tensor([[1, 2, 3], [1, 2, 3]]).to("dml")

That’s it! You’re all set to start running your PyTorch training scripts!
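If the same script has to run on machines with and without the DirectML package, one option is to probe for the device and fall back to the CPU. This is a minimal sketch, assuming that the preview exposes the device as "dml" (as in the snippet above) and that a stock PyTorch build rejects that string with a RuntimeError. pick_device is a hypothetical helper, not part of the package.

```python
import torch


def pick_device(preferred: str = "dml") -> torch.device:
    """Return `preferred` if this PyTorch build can allocate on it, else the CPU."""
    try:
        device = torch.device(preferred)  # stock builds raise RuntimeError for "dml"
        torch.zeros(1, device=device)     # probe allocation: fails if the backend is missing
        return device
    except (RuntimeError, AssertionError):
        # AssertionError covers CPU-only builds when preferred == "cuda".
        return torch.device("cpu")


device = pick_device()
a = torch.tensor([[1, 2, 3], [1, 2, 3]]).to(device)
print(a.device, a.sum().item())
```

The probe allocation is there because constructing a torch.device object alone does not prove that tensors can actually be created on it.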


Try it out and stay involved

We encourage you to try PyTorch-DirectML and give us feedback on how the DirectML backend is working for you!

To help you get started, we created a tutorial for training SqueezeNet and ResNet on GitHub.
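The GitHub tutorial mentioned above covers SqueezeNet and ResNet. As a much smaller illustration of the same pattern (move the model and each batch to the device, then run the usual forward, backward, and optimizer step), here is a hedged sketch. The tiny CNN, the synthetic data, and train_tiny_cnn are hypothetical and not taken from the tutorial; passing "dml" assumes the preview build and that the few layers used here fall within its scoped CNN support.

```python
from typing import List

import torch
import torch.nn as nn


def train_tiny_cnn(device: str = "cpu", steps: int = 25, seed: int = 0) -> List[float]:
    """Fit a tiny CNN to one fixed synthetic batch; return the loss at each step.

    Hypothetical stand-in for a real training script: random 1x8x8 "images"
    with 4 random class labels.
    """
    torch.manual_seed(seed)
    model = nn.Sequential(
        nn.Conv2d(1, 4, kernel_size=3, padding=1),  # keeps the 8x8 spatial size
        nn.ReLU(),
        nn.Flatten(),
        nn.Linear(4 * 8 * 8, 4),                    # 4 output classes
    ).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    loss_fn = nn.CrossEntropyLoss()

    # Create the data on the CPU, then move it, so every device sees identical batches.
    inputs = torch.randn(16, 1, 8, 8).to(device)
    targets = torch.randint(0, 4, (16,)).to(device)

    losses = []
    for _ in range(steps):
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())  # .item() copies the scalar back to the host
    return losses
```

Only the device argument changes between backends; the loop itself is the standard PyTorch training pattern.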

Please leave questions, suggestions, or issues here on GitHub. Our team is constantly monitoring the community’s feedback, and we look forward to hearing from you!


  • Seyyed Hossein Hasanpour

    Thanks a lot to everyone at Microsoft who participated in this and made it a reality. It's a great addition to the PyTorch community.

    • Sunil Chandrasekharan

      This gives an error:
      RuntimeError: generic_type: cannot initialize type "TensorProtoDataType": an object with that name is already defined
      when I try to run it:
      RuntimeError Traceback (most recent call last)
      ~\AppData\Local\Temp/ipykernel_1952/ in
      ----> 1 a = torch.tensor([[1, 2, 3], [1, 2, 3]]).to("dml")

      c:\users\sunil\appdata\local\programs\python\python38\lib\site-packages\pyforest\ in __getattr__(self, attribute)
      68 # called for undefined attribute and returns the attribute of the imported module
      69 def __getattr__(self, attribute):
      ---> 70 self.__maybe_import__()
      71 return eval(f"{self.__imported_name__}.{attribute}")

      c:\users\sunil\appdata\local\programs\python\python38\lib\site-packages\pyforest\ in __maybe_import__(self)
      35 def __maybe_import__(self):
      36 self.__maybe_import_complementary_imports__()
      ---> 37 exec(self.__import_statement__, globals())
      38 # Attention: if the import fails, the next lines will not be reached
      39 self.__was_imported__ = True

      c:\users\sunil\appdata\local\programs\python\python38\lib\site-packages\pyforest\ in

      c:\users\sunil\appdata\local\programs\python\python38\lib\site-packages\torch\ in
      194 if USE_GLOBAL_DEPS:
      195 _load_global_deps()
      --> 196 from torch._C import *
      198 # Appease the type checker; ordinarily this binding is inserted by the

      RuntimeError: generic_type: cannot initialize type "TensorProtoDataType": an object with that name is already defined
