We’re excited to introduce the Azure Storage Connector for PyTorch (azstoragetorch), a new library that brings seamless, performance-optimized integration between Azure Storage and PyTorch. The library makes it easy to access and store data in Azure Blob Storage directly within your training workflows.
What is PyTorch?
PyTorch is a widely used open-source machine learning framework, known for its flexibility and strong support for both research and production deployments. Training models with PyTorch often involves handling large volumes of data: loading massive datasets, saving and restoring model checkpoints, and managing data pipelines. These pipelines can span environments such as local machines, cloud virtual machines, and distributed compute clusters, which adds complexity.
How can the Azure Storage Connector for PyTorch help?
The Azure Storage Connector for PyTorch (azstoragetorch) bridges the powerful storage capabilities of Azure Storage with PyTorch, enabling seamless integrations with key PyTorch APIs for your model training workflows.
The package supports checkpointing models directly to and from Azure Storage with torch.save() and torch.load(), and loading data from Azure Storage directly into PyTorch Dataset classes.
Use the Azure Storage Connector for PyTorch
Installation
The Azure Storage Connector for PyTorch is listed on PyPI, and you can install it using your favorite package manager. This example uses pip:
pip install azstoragetorch
This installs the azstoragetorch library and its dependencies, such as torch and azure-storage-blob.
Authentication
The Azure Storage Connector for PyTorch package’s interfaces default to using the Azure Identity library’s DefaultAzureCredential, which automatically retrieves Microsoft Entra ID tokens based on your current environment. For more information, see DefaultAzureCredential overview.
This means that as long as a credential from the DefaultAzureCredential chain is available on your machine, your credentials are handled securely and automatically, and you’re ready to start using the library.
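When the default chain is not what you want, a credential can also be supplied explicitly. The sketch below assumes that BlobIO accepts a keyword-only credential argument, which is our understanding of the library's documented interface; confirm it against the azstoragetorch API reference before relying on it. The function name open_blob_with_sas and the SAS token value are hypothetical.

```python
def open_blob_with_sas(blob_url, sas_token, mode="rb"):
    """Open a blob with a SAS token instead of DefaultAzureCredential.

    Assumption: BlobIO takes a keyword-only ``credential`` argument.
    """
    # Imported lazily so this module loads even where the Azure packages
    # are not installed; the imports run only when the function is called.
    from azure.core.credentials import AzureSasCredential
    from azstoragetorch.io import BlobIO

    return BlobIO(blob_url, mode, credential=AzureSasCredential(sas_token))


# Hypothetical usage (requires network access and a valid token):
# with open_blob_with_sas(
#     "https://<my-storage-account-name>.blob.core.windows.net/<my-container-name>/model_weights.pth",
#     "<my-sas-token>",
# ) as f:
#     data = f.read()
```

If you rely on DefaultAzureCredential, none of this is needed; the explicit form is mainly useful for short-lived, scoped access.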
Save and Load a Model Checkpoint
The Azure Storage Connector for PyTorch provides the azstoragetorch.io.BlobIO class to save and load models directly to and from Azure Blob Storage. This class implements the file-like object interface that torch.save() and torch.load() accept for model checkpointing in PyTorch.
The BlobIO class takes an Azure Storage Blob URL and the mode you’d like to operate in: either write ("wb") for saving or read ("rb") for loading. Because BlobIO is a file-like object, it can be used in a with statement.
import torch
import torchvision.models  # Install this separately, e.g. ``pip install torchvision``
from azstoragetorch.io import BlobIO

# The URL to your container
CONTAINER_URL = "https://<my-storage-account-name>.blob.core.windows.net/<my-container-name>"

# Your model of choice, in this case ResNet18 for image recognition
model = torchvision.models.resnet18(weights="DEFAULT")

# Save model weights to an Azure Storage Blob named "model_weights.pth" in the container
with BlobIO(f"{CONTAINER_URL}/model_weights.pth", "wb") as f:
    torch.save(model.state_dict(), f)

# Load model weights from model_weights.pth in the Azure Storage Container
with BlobIO(f"{CONTAINER_URL}/model_weights.pth", "rb") as f:
    model.load_state_dict(torch.load(f))
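To make the "file-like" idea concrete, here is a minimal sketch that uses only the standard library. It is not how BlobIO is implemented: InMemoryBlob is a hypothetical stand-in that keeps bytes in a dictionary, and pickle stands in for torch.save() and torch.load(). The point is only that a serializer needs an object with write(), read(), and readline(), not a file on disk.

```python
import io
import pickle


class InMemoryBlob:
    """Hypothetical stand-in for a remote blob; bytes live in a class-level dict."""

    _store = {}

    def __init__(self, name, mode):
        if mode not in ("rb", "wb"):
            raise ValueError(f"unsupported mode: {mode!r}")
        self.name = name
        self.mode = mode
        # Reading starts from the stored bytes; writing starts from an empty buffer.
        self._buffer = io.BytesIO(self._store[name]) if mode == "rb" else io.BytesIO()

    # The serializer only calls these methods; it never needs a path on disk.
    def write(self, data):
        return self._buffer.write(data)

    def read(self, size=-1):
        return self._buffer.read(size)

    def readline(self):
        return self._buffer.readline()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # "Upload" on a clean exit: keep what was written, then release the buffer.
        if self.mode == "wb" and exc_type is None:
            self._store[self.name] = self._buffer.getvalue()
        self._buffer.close()
        return False


state = {"epoch": 3, "weights": [0.1, 0.2, 0.3]}

with InMemoryBlob("checkpoint.pkl", "wb") as f:
    pickle.dump(state, f)

with InMemoryBlob("checkpoint.pkl", "rb") as f:
    restored = pickle.load(f)
```

The save and load calls have the same shape as the BlobIO example above: open in "wb" to write, open in "rb" to read, and let the with statement handle cleanup.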
Sample: Use PyTorch Datasets on Azure Storage
PyTorch has two primitives for loading samples, Dataset and DataLoader. The Azure Storage Connector for PyTorch provides implementations of both types of PyTorch dataset: map-style and iterable-style.
The azstoragetorch.datasets.BlobDataset class is a map-style dataset, enabling random access to data samples. The azstoragetorch.datasets.IterableBlobDataset class is an iterable-style dataset, which should be used when working with large datasets that may not fit in memory.
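The difference between the two styles comes down to which Python protocol the dataset implements. The classes below are plain-Python stand-ins with hypothetical names, not the azstoragetorch or torch classes; they only illustrate random access by index versus sequential iteration.

```python
class MapStyleBlobs:
    """Map-style: supports len() and random access by integer index."""

    def __init__(self, blob_names):
        self._blob_names = list(blob_names)

    def __len__(self):
        return len(self._blob_names)

    def __getitem__(self, index):
        return self._blob_names[index]


class IterableBlobs:
    """Iterable-style: yields samples in order, one at a time."""

    def __init__(self, blob_name_source):
        # A zero-argument callable that returns a fresh iterator on each call,
        # so the full listing never has to be held in memory at once.
        self._source = blob_name_source

    def __iter__(self):
        yield from self._source()


names = ["a.jpg", "b.jpg", "c.jpg"]

map_style = MapStyleBlobs(names)
third = map_style[2]        # jump straight to any sample
count = len(map_style)      # size is known up front

iterable = IterableBlobs(lambda: iter(names))
streamed = list(iterable)   # samples arrive only in sequence
```

A map-style dataset lets a DataLoader shuffle by index, while an iterable-style dataset trades that for streaming access.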
Both classes support two methods for creating a dataset: from_container_url() and from_blob_urls(). The from_container_url() method instantiates a dataset by listing blobs in a container, and the from_blob_urls() method takes a list of blob URLs when creating a dataset.
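When you already know which blobs you want, from_blob_urls() needs full blob URLs. The following sketch builds them from a container URL and a list of blob names. The helper names build_blob_urls and dataset_from_names are hypothetical, the account and container in the usage are placeholders, and the call to BlobDataset.from_blob_urls() assumes the list of URLs is its first positional argument.

```python
from urllib.parse import quote


def build_blob_urls(container_url, blob_names):
    """Join a container URL with blob names, percent-encoding each name."""
    base = container_url.rstrip("/")
    # Keep "/" unescaped so virtual directory separators survive encoding.
    return [f"{base}/{quote(name, safe='/')}" for name in blob_names]


def dataset_from_names(container_url, blob_names):
    """Create a map-style dataset from specific blob names (hypothetical helper)."""
    # Imported lazily so build_blob_urls() works without azstoragetorch installed.
    from azstoragetorch.datasets import BlobDataset

    return BlobDataset.from_blob_urls(build_blob_urls(container_url, blob_names))


urls = build_blob_urls(
    "https://myaccount.blob.core.windows.net/mycontainer/",
    ["datasets/caltech101/dalmatian/image_0001.jpg", "notes/read me.txt"],
)
```

Use from_container_url() with a prefix when you want everything under a path, and from_blob_urls() when the selection comes from elsewhere, such as a manifest file.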
These Dataset integrations fit naturally into a PyTorch workflow. Let’s dive into an image example using the caltech101 dataset and the resnet18 model.
The caltech101 dataset contains about 9,000 images across various categories, and the resnet18 model is a residual neural network introduced in the paper Deep Residual Learning for Image Recognition.
Prerequisites and Setup
First, install the prerequisite packages azstoragetorch, pillow, and torchvision:
pip install azstoragetorch pillow torchvision
Once you have the packages installed, ensure you have the caltech101 dataset in your Azure Storage Account. To ease this setup process, clone the GitHub repository and run the provided setup file.
Lastly, it’s helpful to understand the importance of transforming data for PyTorch operations. By default, each dataset sample in the azstoragetorch package is a dictionary representing a blob in the dataset. It’s often necessary to transform this data into the shape needed for your PyTorch workflows.
To override the default output, pass a transform callable to the from_blob_urls() or from_container_url() method. The callable accepts a single argument of type azstoragetorch.datasets.Blob. The transform callable in the following code sample is based on the PyTorch documentation. To learn more about how to use the transform callable in the azstoragetorch library, visit the documentation.
from torch.utils.data import DataLoader
import torchvision.transforms
from PIL import Image
from azstoragetorch.datasets import BlobDataset


# Function to transform blob data to a torch.Tensor
# To learn more about why these particular transforms were chosen,
# visit the documentation site: https://pytorch.org/hub/pytorch_vision_resnet/
def blob_to_category_and_tensor(blob):
    with blob.reader() as f:
        img = Image.open(f).convert("RGB")
        img_transform = torchvision.transforms.Compose([
            torchvision.transforms.Resize(256),
            torchvision.transforms.CenterCrop(224),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            ),
        ])
        img_tensor = img_transform(img)
    # Get second to last component of blob name which will be the image category
    # Example: blob.blob_name -> datasets/caltech101/dalmatian/image_0001.jpg
    #          category -> dalmatian
    category = blob.blob_name.split("/")[-2]
    return category, img_tensor


# The URL to your container
CONTAINER_URL = "https://<my-storage-account-name>.blob.core.windows.net/<my-container-name>"

# Initialize dataset with the azstoragetorch library
dataset = BlobDataset.from_container_url(
    CONTAINER_URL,
    prefix="datasets/caltech101",
    transform=blob_to_category_and_tensor,
)

# Set up data loader
loader = DataLoader(dataset, batch_size=32)
for categories, img_tensors in loader:
    print(categories, img_tensors.size())
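The loader above yields category names as strings, while a training loop usually needs integer class labels. The following is a small sketch of one way to derive them, using only the standard library. FakeBlob is a hypothetical stand-in that mimics just the blob_name attribute used in the sample, and category_from_blob and build_label_index are illustrative helpers, not part of azstoragetorch.

```python
from collections import namedtuple

# Hypothetical stand-in exposing only the attribute the transform reads.
FakeBlob = namedtuple("FakeBlob", ["blob_name"])


def category_from_blob(blob):
    """Return the second-to-last path component of the blob name."""
    parts = blob.blob_name.split("/")
    if len(parts) < 2:
        raise ValueError(f"blob name has no category folder: {blob.blob_name!r}")
    return parts[-2]


def build_label_index(categories):
    """Map each distinct category name to a stable integer label."""
    # Sorting makes the mapping independent of the order blobs are listed in.
    return {name: index for index, name in enumerate(sorted(set(categories)))}


blobs = [
    FakeBlob("datasets/caltech101/dalmatian/image_0001.jpg"),
    FakeBlob("datasets/caltech101/accordion/image_0001.jpg"),
    FakeBlob("datasets/caltech101/dalmatian/image_0002.jpg"),
]
categories = [category_from_blob(blob) for blob in blobs]
label_index = build_label_index(categories)
labels = [label_index[name] for name in categories]
```

Building the index once, before training, keeps labels consistent across epochs and across worker processes.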
Conclusion
The Azure Storage Connector for PyTorch is designed around the principle that cloud storage integration for your ML workflows shouldn’t require learning new paradigms. Key features include:
- Zero configuration: Automatic credential discovery means no setup code
- Familiar patterns: If you know open() and PyTorch datasets, you already know this library
- Framework integration: Direct compatibility with torch.save(), torch.load(), and DataLoader
- Flexible access: Support for both container-wide and specific blob/object access patterns for reads and writes
- Debugging friendly: Clear error messages and standard Python exceptions
Install azstoragetorch today to enable several use cases for machine learning workflows that use data stored in Azure Blob Storage:
- Distributed Training: Save and load model checkpoints across multiple nodes without shared file systems
- Model Sharing: Easily share trained models across teams and environments
- Dataset Management: Access large datasets stored in Azure Blob Storage without local storage constraints
- Experimentation: Quickly iterate on different models and datasets without data movement overhead
The Azure Storage Connector for PyTorch is in Public Preview and we’re actively seeking feedback from the community. The project is open source and available on GitHub, where we’d love to get your feedback, feature requests, and suggestions for future integrations.
Resources
- Azure Storage Connector for PyTorch (azstoragetorch) (Preview)
- azstoragetorch Samples and Quickstart
- Data-intensive AI Training and Inferencing with Azure Blob Storage
- Quickstart: Azure Blob Storage client library for Python
For feature requests, bug reports, or general support, open an issue in the repository on GitHub.