by
Meta: Iris Zhang, Less Wright, Rodrigo Kumpera, Chien-Chin Huang; IBM: Davis Wertheimer, Supriyo Chakraborty, Sophia Wen, Raghu Ganti, Mudhakar Srivatsa, Seetharami Seelam
Last year, IBM Research began collaborating with us to onboard Fully Sharded Data Parallelism (FSDP) for their large foundation models. They were drawn to FSDP because it is a PyTorch-native offering for scaling their distributed training efforts on IBM Cloud.
We are pleased to share that, in collaboration with IBM, we have achieved substantial checkpointing speedups for large models (72x over the original PyTorch 1.13 save speed), proven model and optimizer checkpointing at 30B-parameter scale, and enabled cloud-first training using FSDP + Distributed Checkpoint on S3 backends.
What is a Distributed Checkpoint?
Distributed checkpointing is the PyTorch native solution for saving and loading PyTorch models and optimizer states from multiple ranks, as well as supporting dynamically changing world sizes between reloads.
PyTorch Distributed Checkpoint (DCP) APIs were introduced in PyTorch 1.13, and are included as an official prototype feature in PyTorch 2.0.
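For context, here is a minimal sketch (not taken from the IBM work) of what a save looks like with the prototype DCP APIs in PyTorch 2.0; the tiny model, checkpoint path, and single-key state_dict layout are illustrative:

```python
import torch
import torch.distributed.checkpoint as dist_cp

# Assumes the process group has already been initialized (e.g. via torchrun)
# and that `model` is the module being trained; a tiny Linear stands in here.
model = torch.nn.Linear(8, 8)

state_dict = {"model": model.state_dict()}

# Each rank writes its own file(s); DCP coordinates the shared metadata.
dist_cp.save_state_dict(
    state_dict=state_dict,
    storage_writer=dist_cp.FileSystemWriter("/tmp/checkpoint"),
)
```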
Distributed checkpoint is different from torch.save() and torch.load() in a few significant ways:
- DCP produces multiple files per checkpoint, with at least one file per rank.
- DCP operates in place, meaning that the model should allocate its storage first and DCP then loads the checkpoint directly into that pre-allocated storage (see the sketch after this list).
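To illustrate the in-place behavior, the following sketch (path and keys again illustrative) first constructs the model so its tensors exist, lets DCP read the checkpoint into that storage, and then copies the result back into the module:

```python
import torch
import torch.distributed.checkpoint as dist_cp

# Allocate the model (and therefore its tensors) before loading;
# DCP reads the checkpoint directly into this pre-allocated storage.
model = torch.nn.Linear(8, 8)
state_dict = {"model": model.state_dict()}

dist_cp.load_state_dict(
    state_dict=state_dict,
    storage_reader=dist_cp.FileSystemReader("/tmp/checkpoint"),
)

# Copy the loaded tensors back into the module's parameters.
model.load_state_dict(state_dict["model"])
```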
A major improvement from 1.13 to 2.0 is the addition of sharded_state_dict support, which allows each rank to save and load only its own shards of an FSDP model.
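As a rough sketch of how that support is typically used (assuming `model` is already FSDP-wrapped; the checkpoint path is illustrative), the sharded_state_dict context lets each rank checkpoint only its local shards through DCP:

```python
import torch.distributed.checkpoint as dist_cp
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType

# `model` is assumed to already be wrapped in FSDP.
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
    state_dict = {"model": model.state_dict()}
    # Each rank saves only its local shards; DCP can later reload them
    # even if the world size has changed between runs.
    dist_cp.save_state_dict(
        state_dict=state_dict,
        storage_writer=dist_cp.FileSystemWriter("/tmp/fsdp_checkpoint"),
    )
```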