[Paper Review] ZeRO: Memory Optimizations Toward Training Trillion Parameter Models

2022. 6. 27. 11:41 · Paper Review

1. Extended Introduction

DP (Data Parallelism) runs out of memory for models with more than 1.4B parameters on the current generation of GPUs with 32GB of memory. MP (Model Parallelism) requires model refactoring and has significant communication overhead. To overcome these limitations, we first analyze the full spectrum of memory consumption of existing systems during model training and classify it into two parts: 1) model states and 2) residual states. We then develop ZeRO, which optimizes memory efficiency for both while retaining high compute and communication efficiency.

1.1. Optimizing Model State Memory

1) DP has good compute/communication efficiency but poor memory efficiency, while MP has good memory efficiency but poor compute/communication efficiency. 2) Existing approaches statically maintain all the model states required over the entire training process, even though not all model states are needed all the time. Based on these observations, we develop ZeRO-DP, ZeRO-powered data parallelism, which achieves the compute/communication efficiency of DP together with the memory efficiency of MP. ZeRO-DP 1) removes the memory state redundancy across data-parallel processes by partitioning the model states instead of replicating them, and 2) retains the compute/communication efficiency of DP by preserving its computational granularity and communication volume using a dynamic communication schedule during training.

1.2. Optimizing Residual State Memory

We develop ZeRO-R to optimize the residual memory consumed by residual states:

  1. ZeRO-R optimizes activation memory by identifying and removing activation replication in existing MP approaches through activation partitioning. It also offloads activations to CPU when appropriate.
  2. ZeRO-R defines an appropriate size for temporary buffers to strike a balance between memory and computation efficiency.
  3. ZeRO-R proactively manages memory based on the different lifetimes of tensors, preventing memory fragmentation.

1.3. Implementation & Evaluation

The complete set of ZeRO optimizations could run models with a trillion parameters; however, the training time could be impractically long. Therefore, we implement and evaluate ZeRO-100B (the Pos+g subset of ZeRO-DP plus ZeRO-R), which efficiently supports models with 10x more parameters than SOTA. The results show:

  1. Model Size: Combined with MP, ZeRO-100B runs 170B-parameter models efficiently.
  2. Speed: Due to its memory efficiency, training speed improves by more than 10x.
  3. Scalability: We observe super-linear speedup in the regime of 64-400 GPUs, because ZeRO-DP reduces the per-GPU memory footprint of the model states as the DP degree increases, allowing us to fit larger batch sizes per GPU and resulting in better performance.
  4. Democratization of Large Model Training: ZeRO-100B empowers data scientists to train models with up to 13B parameters without model refactoring. Data scientists can thus experiment freely with large models without worrying about parallelism.
  5. New SOTA Model: ZeRO powers the largest language model with 17B parameters and record-breaking accuracy, Turing-NLG.

2. Related Work

2.1. Data, Model and Pipeline Parallelism

Parallelization is a key strategy for training large models at scale. DP replicates the model parameters on each device, while MP splits the model across processes vertically.
In PP: 1) model functionalities such as tied weights and batch normalization are difficult to implement due to horizontal splitting and micro-batching, respectively; 2) GPipe requires a batch size proportional to the number of pipeline partitions to hide the pipeline bubble, which adversely affects the convergence rate; and 3) PipeDream keeps multiple copies of stale parameters to hide the pipeline bubble, making it less memory efficient. In contrast, ZeRO obtains the same or better memory efficiency than PP without incurring the functionality, performance, and convergence related restrictions of PP.

2.2. Non-parallelism based approach to reduce memory

2.2.1 Reducing Activation Memory

Multiple efforts have focused on reducing the memory footprint of activations through compression, activation checkpointing, or live analysis. Activation memory reduction in ZeRO-R works in parallel with activation checkpointing.

2.2.2 CPU Offload

Some research offloads model states to CPU memory, through algorithmic design or virtualized memory. Up to 50% of training time can be spent on GPU-CPU-GPU transfers. In rare cases, ZeRO-R may offload just the activation checkpoints of very large models to improve performance.

2.2.3 Memory Efficient Optimizer

Some previous work focuses on reducing the memory consumption of adaptive optimization methods by maintaining coarser-grained statistics of model parameters and gradients, with a potential impact on model convergence guarantees. ZeRO is orthogonal to these efforts.

2.3 Training Optimizers

Adaptive optimization methods are crucial for achieving SOTA performance and accuracy when training large models. They maintain fine-grained first-order and second-order statistics for each model parameter and gradient, at the cost of a significant memory footprint. ZeRO can reduce the memory footprint of these optimizers by orders of magnitude.

3. Where Did All the Memory Go?

Let's take a step back and examine the memory consumption of current training systems.

3.1. Model States: Optimizer States, Gradients and Parameters

The majority of device memory is consumed by model states during training. For example, if we train a model with ADAM, there has to be enough memory to hold 1) the momentum and 2) the variance of the gradients, in addition to 3) the gradients and 4) the weights themselves.

3.1.1. Mixed-Precision Training

The state-of-the-art approach to training large models on the current generation of NVIDIA GPUs is mixed-precision training. During mixed-precision training, both forward and backward propagation are performed with fp16 parameters, activations, and gradients. However, to effectively compute and apply the updates at the optimizer step, the mixed-precision optimizer keeps an fp32 copy of the parameters and gradients as well as the fp32 optimizer states.
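To make this concrete, here is a quick back-of-the-envelope calculation following the paper's accounting for mixed-precision ADAM: 2 bytes per parameter each for the fp16 parameters and fp16 gradients, plus K = 12 bytes per parameter for the fp32 parameter copy, momentum, and variance.

```python
# Model-state memory for mixed-precision ADAM training, per the paper's accounting:
# 2 bytes (fp16 params) + 2 bytes (fp16 grads) + 12 bytes (fp32 params, momentum, variance).
def model_state_memory_gb(num_params, k=12):
    bytes_per_param = 2 + 2 + k
    return num_params * bytes_per_param / 1e9

print(model_state_memory_gb(1.5e9))  # GPT-2 (1.5B params) -> ~24 GB of model states
print(model_state_memory_gb(7.5e9))  # 7.5B params         -> ~120 GB, far beyond a single 32 GB GPU
```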

3.2. Residual Memory Consumption

3.2.1. Activations

Activations can take up a significant amount of memory during training. Activation checkpointing is a common approach to reduce the activation memory by approximately the square root of the total activations at the expense of 33% re-computation overhead. Despite the significant reduction, the activation memory can grow quite large for bigger models even with activation checkpointing.

3.2.2. Temporary buffers

Temporary buffers are used for storing intermediate results. Operations such as gradient all-reduce or gradient norm computation tend to fuse all the gradients into a single flattened buffer before applying the operation, in an effort to improve throughput. When the model is large, these temporary buffer sizes become non-trivial.

3.2.3. Memory Fragmentation

It is possible to run out of usable memory even when there is plenty of available memory, due to memory fragmentation. A memory request will fail if there isn't enough contiguous memory to satisfy it, even if the total available memory is larger than requested.

4. ZeRO: Insights and Overview

We present an overview of the optimizations and the insights behind them, which allow ZeRO to reduce the memory footprint while remaining efficient. Note that efficiency is key here.

4.1. Insights and Overview: ZeRO-DP

ZeRO powered DP is based on three key insights:

  1. DP has better scaling efficiency than MP because MP reduces the granularity of the computation while also increasing the communication overhead.
  2. DP is memory inefficient as model states are stored redundantly across all data-parallel processes.
  3. Both DP and MP keep all the model states needed over the entire training process, but not everything is required all the time.

Based on these insights, ZeRO-DP retains the training efficiency of DP, while achieving the memory efficiency of MP by partitioning model states and using dynamic communication schedule.

4.2. Insights and Overview: ZeRO-R

4.2.1. Reducing Activation Memory

Two key insights are:

  1. MP partitions the model states but often requires replication of the activation memory.
  2. The arithmetic intensity (the ratio of the amount of computation per iteration to the amount of activation checkpoints per iteration) increases linearly with the hidden dimension, making it possible to hide the data-movement cost of the activation checkpoints.

ZeRO removes the memory redundancy in MP by partitioning the activation checkpoints across GPUs and using all-gather to reconstruct them on demand.

4.2.2. Managing Temporary Buffers

ZeRO-R uses constant-size buffers to prevent temporary buffers from blowing up as the model size increases.

4.2.3. Managing fragmented Memory

Memory fragmentation is a result of interleaving between short-lived and long-lived memory objects. ZeRO therefore performs on-the-fly memory defragmentation by moving activation checkpoints and gradients to pre-allocated contiguous memory buffers.

5. Deep Dive into ZeRO-DP

While the existing DP approach replicates the model states on each device and introduces significant memory overhead, ZeRO-DP eliminates this memory redundancy by partitioning them (optimizer states, gradients, and parameters) across data-parallel processes.

5.1. Pos: Optimizer State Partitioning

Each data-parallel process only needs to store and update 1/Nd of the total optimizer states, and then only updates 1/Nd of the parameters. We perform an all-gather across the data-parallel processes at the end of each training step to obtain the fully updated parameters. This leads to a 4x memory reduction.
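As a rough illustration (not the paper's actual implementation), the sketch below shows what a Pos-style step could look like for a single flat fp16 parameter tensor: each rank holds fp32 master weights and ADAM moments only for its own 1/Nd shard, updates that shard, and an all-gather rebuilds the full fp16 parameters. The helper name and the assumption that the parameter count divides evenly by Nd are mine.

```python
import torch
import torch.distributed as dist

def partitioned_adam_step(fp16_params, fp16_grads, fp32_shard, exp_avg, exp_avg_sq,
                          lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """Illustrative P_os step: fp32 master weights and ADAM moments exist only
    for this rank's 1/Nd shard; the other shards live on the other ranks."""
    world_size, rank = dist.get_world_size(), dist.get_rank()
    shard_size = fp16_params.numel() // world_size        # assumes numel divisible by Nd
    lo, hi = rank * shard_size, (rank + 1) * shard_size

    g = fp16_grads[lo:hi].float()                         # local shard of the already-reduced gradients
    exp_avg.mul_(beta1).add_(g, alpha=1 - beta1)          # ADAM moments, shard only (bias correction omitted)
    exp_avg_sq.mul_(beta2).addcmul_(g, g, value=1 - beta2)
    fp32_shard.addcdiv_(exp_avg, exp_avg_sq.sqrt().add_(eps), value=-lr)

    # All-gather the updated shards so every rank sees the full fp16 parameters.
    gathered = [torch.empty(shard_size, dtype=torch.float16, device=fp16_params.device)
                for _ in range(world_size)]
    dist.all_gather(gathered, fp32_shard.half())
    fp16_params.copy_(torch.cat(gathered))
```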

5.2. Pg: Gradient Partitioning

Each data-parallel process only needs the reduced gradients for its corresponding parameters. Therefore, as the gradients of each layer become available during backward propagation, we reduce them only onto the data-parallel process responsible for updating the corresponding parameters (overlapping computation and communication). After the reduction, the gradients are no longer needed and their memory can be released. This leads to an 8x memory reduction.
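A minimal sketch of this idea, assuming gradients are grouped into Nd buckets and bucket i is owned by rank i (the paper's implementation uses bucketized reduce-scatter overlapped with the backward pass):

```python
import torch.distributed as dist

def reduce_gradient_bucket(bucket_grad, owner_rank):
    # Sum this bucket's gradients onto the rank that updates the corresponding parameters.
    dist.reduce(bucket_grad, dst=owner_rank, op=dist.ReduceOp.SUM)
    if dist.get_rank() != owner_rank:
        # Non-owner ranks no longer need this bucket; its memory can be released.
        bucket_grad = None
    return bucket_grad
```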

5.3. Pp: Parameter Partitioning

Each process only stores the parameters corresponding to its partition. When parameters outside its partition are required for forward or backward propagation, they are received from the appropriate data-parallel process through a broadcast. This leads to a memory reduction proportional to Nd.
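The sketch below illustrates the just-in-time flavor of Pp under simplifying assumptions of mine (one owner rank per layer, a `shapes` dict describing the owner's parameter shapes, and no overlap of communication with computation, which the real system does perform):

```python
import torch
import torch.distributed as dist

def run_layer_with_partitioned_params(layer, x, owner_rank, shapes):
    rank = dist.get_rank()
    for name, p in layer.named_parameters():
        if rank != owner_rank:
            p.data = torch.empty(shapes[name], dtype=p.dtype, device=x.device)  # receive buffer, allocated just in time
        dist.broadcast(p.data, src=owner_rank)   # owner sends the full weights for this layer
    out = layer(x)
    if rank != owner_rank:
        for p in layer.parameters():
            p.data = torch.empty(0, dtype=p.dtype, device=x.device)             # drop the temporary replica
    return out
```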

5.4. Implication on Model Size

Note that with Nd=64, ZeRO can train models with up to 7.5B, 14B, and 128B parameters using Pos, Pos+g, and Pos+g+p, respectively. When Nd=1024, ZeRO with all of its optimizations enabled (Pos+g+p) could train models with 1 TRILLION parameters!
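These model sizes follow from the paper's per-GPU memory formulas; a quick script to reproduce the per-GPU numbers for a 7.5B-parameter model at Nd = 64 (with K = 12 for mixed-precision ADAM):

```python
def per_gpu_memory_gb(psi, nd, k=12):
    # Per-GPU model-state memory under each ZeRO-DP stage, following the paper's formulas
    # (bytes per parameter: 2 fp16 params + 2 fp16 grads + k optimizer states).
    return {
        "baseline DP": (2 + 2 + k) * psi / 1e9,
        "P_os":        (2 + 2) * psi / 1e9 + k * psi / (nd * 1e9),
        "P_os+g":      2 * psi / 1e9 + (2 + k) * psi / (nd * 1e9),
        "P_os+g+p":    (2 + 2 + k) * psi / (nd * 1e9),
    }

print(per_gpu_memory_gb(7.5e9, nd=64))
# -> baseline ~120 GB, P_os ~31.4 GB, P_os+g ~16.6 GB, P_os+g+p ~1.9 GB (fits a 32 GB V100)
```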

6. Deep Dive into ZeRO-R

6.1. Pa: Partitioned Activation Checkpointing

MP requires a replication of the activations, resulting in redundant copies across the model-parallel GPUs. ZeRO eliminates this redundancy by partitioning the activations during forward propagation and using an all-gather operation to re-materialize a replicated copy during backward propagation. It works in conjunction with activation checkpointing. Furthermore, for very large models with very limited device memory, these partitioned activation checkpoints can also be offloaded to the CPU; we refer to this as Pa+cpu. Partitioning reduces the activation footprint by a factor proportional to the MP degree.
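A minimal sketch of the mechanism, assuming the checkpointed activation's leading dimension is divisible by the MP degree and using plain torch.distributed collectives over the model-parallel group (the real implementation integrates this with activation checkpointing):

```python
import torch
import torch.distributed as dist

def save_partitioned_checkpoint(activation, mp_group):
    # Keep only this MP rank's 1/mp_size slice of the checkpointed activation.
    mp_size = dist.get_world_size(group=mp_group)
    mp_rank = dist.get_rank(group=mp_group)
    return activation.chunk(mp_size, dim=0)[mp_rank].clone()

def gather_checkpoint(partition, mp_group):
    # Re-materialize the full activation right before recomputation in the backward pass.
    mp_size = dist.get_world_size(group=mp_group)
    parts = [torch.empty_like(partition) for _ in range(mp_size)]
    dist.all_gather(parts, partition.contiguous(), group=mp_group)
    return torch.cat(parts, dim=0)
```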

6.2. CB: Constant Size Buffers

During training, the computational efficiency of some operations is higher with larger input sizes, so high-performance libraries fuse all the parameters into a single buffer before applying these operations. However, the memory overhead of such fused buffers is proportional to the model size and can become prohibitive. To address this, we simply use a constant-size fused buffer when the model becomes too large.
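A sketch of the constant-buffer idea, assuming contiguous gradients and that no single gradient exceeds the buffer; the bucket size here is an arbitrary example, not the value ZeRO-100B actually uses:

```python
import torch
import torch.distributed as dist

BUCKET_ELEMS = 2 ** 25   # fixed staging-buffer size, independent of model size (example value)

def allreduce_with_constant_buffer(grads):
    buffer = torch.empty(BUCKET_ELEMS, dtype=grads[0].dtype, device=grads[0].device)
    bucket, filled = [], 0
    for g in grads:                                   # assumes each gradient fits in the buffer
        if filled + g.numel() > BUCKET_ELEMS:
            _flush(bucket, buffer)                    # fuse + all-reduce the current chunk
            bucket, filled = [], 0
        bucket.append(g)
        filled += g.numel()
    if bucket:
        _flush(bucket, buffer)

def _flush(bucket, buffer):
    offset = 0
    for g in bucket:                                  # copy gradients into the fixed fused buffer
        buffer[offset:offset + g.numel()].copy_(g.view(-1))
        offset += g.numel()
    flat = buffer[:offset]
    dist.all_reduce(flat)                             # one fused collective per chunk
    offset = 0
    for g in bucket:                                  # scatter the reduced values back
        g.view(-1).copy_(flat[offset:offset + g.numel()])
        offset += g.numel()
```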

6.3. MD: Memory Defragmentation

Interleaving between short-lived memory (e.g., discarded activations, activation gradients) and long-lived memory (e.g., checkpointed activations, parameters, parameter gradients) causes fragmentation. This leads to two issues: 1) OOM due to a lack of contiguous memory, and 2) the memory allocator spending significant time searching for a contiguous block to satisfy a memory request. ZeRO performs memory defragmentation on-the-fly by pre-allocating contiguous memory chunks for activation checkpoints and gradients and copying them over to the pre-allocated memory as they are produced.
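A minimal sketch of the idea, with names of my own choosing: one up-front contiguous allocation is reused every iteration, and long-lived tensors are copied into it as they are produced so they never interleave with short-lived allocations.

```python
import torch

class ContiguousPool:
    """Pre-allocated contiguous region for long-lived tensors (activation checkpoints, gradients)."""
    def __init__(self, numel, dtype=torch.float16, device="cuda"):
        self.buffer = torch.empty(numel, dtype=dtype, device=device)  # single up-front allocation
        self.offset = 0

    def move_in(self, tensor):
        n = tensor.numel()
        slot = self.buffer[self.offset:self.offset + n].view_as(tensor)
        slot.copy_(tensor)          # relocate into the contiguous region
        self.offset += n
        return slot                 # caller keeps this view; the original, fragmented copy can be freed

    def reset(self):
        self.offset = 0             # reuse the pool on the next iteration
```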

7. Communication Analysis of ZeRO-DP

7.1. Communication Volume with Pos+g

ZeRO only requires a reduce-scatter operation on the gradients, incurring a communication volume of Ψ. After each process updates the partition of the parameters it is responsible for, an all-gather is performed to collect the updated parameters from all data-parallel processes, which also incurs a communication volume of Ψ. The total communication volume is therefore 2Ψ, the same as baseline DP.

7.2. Communication Volume with Pos+g+p

We reschedule the parameter all-gather by spreading it across the entire forward propagation and discarding the parameters once they have been used. Note that this all-gather needs to happen once again for the backward propagation, in reverse order. The parameter communication volume is therefore 2Ψ. Additionally, we still need the gradient reduce-scatter, which incurs a communication volume of Ψ. The total volume is therefore 3Ψ, only 1.5x that of baseline DP.
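Putting the counts side by side (a standard DP all-reduce is itself 2Ψ, implemented as a reduce-scatter followed by an all-gather), a tiny helper to tabulate the volumes derived above:

```python
def dp_comm_volume(psi):
    # Per-step communication volume in parameter-count units, as derived above.
    return {
        "standard DP (all-reduce)": 2 * psi,        # reduce-scatter + all-gather
        "ZeRO P_os+g":              psi + psi,      # gradient reduce-scatter + parameter all-gather
        "ZeRO P_os+g+p":            2 * psi + psi,  # param all-gathers (fwd + bwd) + gradient reduce-scatter
    }

print(dp_comm_volume(1.0))  # -> 2.0, 2.0, 3.0: P_os+g matches DP; P_os+g+p is only 1.5x
```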

8. Communication Analysis of ZeRO-R

We compare the communication volume of partitioned activation checkpointing in ZeRO-R with baseline MP.

In Megatron-LM with activation checkpointing, each transformer block performs two all-reduce operations in the forward propagation, two all-reduces for forward re-computation, and two more in the backward propagation. The total communication per block is 12×message_size, since the communication volume of an all-reduce is 2×message_size.

When ZeRO-R partitions activation checkpoints, it requires an additional all-gather operation on each activation checkpoint before the forward re-computation during backward propagation. In general, we checkpoint the input activation of each transformer block, requiring one all-gather per block. The communication overhead of Pa is therefore message_size.

When MP is used in conjunction with DP, Pa can be used to reduce the data-parallel communication volume by an order of magnitude at the expense of a 10% increase in model-parallel communication volume. This is because Pa reduces the activation memory consumption by the MP degree, allowing a proportional increase in batch size; the per-step DP gradient communication, which is independent of batch size, is then amortized over many more samples.

In extreme cases where DP communication volume is the major bottleneck due to a small batch size even with Pa, Pa+cpu can improve efficiency by increasing the batch size as long as the CPU data transfer overhead is less than the DP communication volume overhead.

9. Step Towards 1 Trillion Parameters

Getting to a trillion parameters will inevitably happen, but the road will be full of hurdles, surprises, and innovations. ZeRO addresses one of the most fundamental challenges from a system perspective: the ability to fit a model of this scale on current hardware while allowing it to train with good system scalability. In other words, ZeRO vastly increases the efficiently-runnable model size. It enables the current generation of hardware to run significantly larger models without requiring fine-grained model parallelism to cross node boundaries. The results show that ZeRO, with all optimizations, could fit more than 1 trillion parameters on 1024 GPUs using DP only. Alternatively, when combined with model parallelism, ZeRO could fit more than 1 trillion parameters on 1024 GPUs with 16-way model parallelism. However, training such a model would still require an amount of compute power that today's AI clusters lack.

10. Implementation and Evaluation

We focus our implementation on efficiently supporting models with ~100B parameters that remain trainable within a reasonable time frame on current hardware.

10.1. Implementation and Methodology

Implementation: ZeRO-100B is implemented in PyTorch and includes the full set of optimizations in Pos+g and ZeRO-R. Its interface is compatible with any model implemented as a torch.nn.Module, so users do not need to modify their models. It can be combined with any form of MP, including Megatron-LM.

Hardware: 400 V100 GPUs (25 DGX-2 nodes) with 800 Gbps internode communication bandwidth.

Baseline: For experiments without MP, we use torch's distributed data parallel (DDP) as baseline. For experiments with MP, we use Megatron-LM.

ZeRO: Experiments without MP use the ZeRO-powered DP implemented in ZeRO-100B. Experiments with MP combine ZeRO-powered DP with the MP of Megatron-LM.

Model Configurations: The models presented in this section are GPT-2 like transformer based models.

10.2. Speed and Model size

ZeRO-100B efficiently runs models with up to 170B parameters on 400 GPUs, more than 8x bigger than Megatron-LM. The results show that ZeRO-100B achieves a sustained throughput of 15 PetaFlops on average for models with 8B to 100B parameters. In comparison, the baseline MP performance degrades quickly as the model size increases. ZeRO-100B achieves up to a 10x speedup over the baseline, significantly outperforming it on large models.

10.3. Super-Linear Scalability

ZeRO-100B demonstrates super-linear scalability for very large model sizes. Pos+g reduces the per-GPU memory consumption of ZeRO-100B as the DP degree increases, allowing ZeRO-100B to fit larger batch sizes per GPU and thus achieve higher throughput per GPU.

10.4. Democratizing Large Model Training

ZeRO does not require any changes to the model itself, and it can be used as simply as baseline DP while delivering a significant boost in model size and speed.

10.5. Memory and Performance Analysis

  Config   ZeRO-DP   ZeRO-R          Max Model Size
  C1       Pos       CB+MD           40B
  C2       Pos       CB+MD+Pa        60B
  C3       Pos+g     CB+MD           50B
  C4       Pos+g     CB+MD+Pa        140B
  C5       Pos+g     CB+MD+Pa+cpu    150B

Maximum Model Size: The results are shown in the table above.

Max Cached Memory: The difference in memory consumption between C2 and C3 depends on the size of the model states relative to the activation memory. It is noteworthy that the cached memory does not decrease from C4 to C5 for the 40B model, but it does for the 100B model.

Max Achievable Performance: The performance improvement corresponds to the decrease in memory consumption between the optimizations. The only caveat is the performance drop between C4 and C5 for the 60B-parameter model despite lower memory consumption.

10.6. Turing-NLG, the SOTA language model with 17B parameters

Turing-NLG achieved a new SOTA for language models with a WikiText-103 perplexity of 10.21. Turing-NLG was trained end-to-end using ZeRO-100B, which achieves a sustained throughput of 41.4 TeraFlops/GPU for this model.

11. Concluding Remarks

From an HPC and systems perspective, we believe that ZeRO represents a revolutionary transformation in the large-model training landscape. Our implementation, ZeRO-100B, enables an 8x increase in model size, over a 10x improvement in throughput, and super-linear speedups on modern GPU clusters. Perhaps what we feel most optimistic about is that ZeRO imposes no hurdles on data scientists.