GShard: Scaling Giant Models with Conditional Computation and Automatic Sharding

30 Jun 2020 | Dmitry Lepikhin, HyoukJoong Lee, Yuanzhong Xu, Dehao Chen, Orhan Firat, Yanping Huang, Maxim Krikun, Noam Shazeer, Zhifeng Chen
GShard is a module that enables efficient scaling of giant neural networks through conditional computation and automatic sharding. It consists of a set of lightweight annotation APIs and an extension to the XLA compiler, letting users express a wide range of parallel computation patterns with minimal changes to existing model code. The authors use GShard to scale a multilingual neural machine translation Transformer with Sparsely-Gated Mixture-of-Experts (MoE) layers beyond 600 billion parameters. Trained on 2048 TPU v3 accelerators in 4 days, the model achieves far superior quality for translation from 100 languages to English compared to prior art.

The paper first lays out the practical challenges of scaling giant models: efficient model parallelism, super-linear scaling of computation cost with model size, limited infrastructure scalability for representing giant models, and the non-trivial engineering effort needed to implement partitioning strategies. To address these challenges, it follows design principles that keep computation cost sub-linear in model size and compilation time constant. Scaling is done sparsely: every other feed-forward layer of the Transformer is replaced with a Position-wise MoE layer, and a gating function routes each token to only a small number of experts (the top two in the paper), so per-token computation stays nearly constant as the number of experts grows, enabling efficient training and inference.
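To make the sparse-scaling recipe concrete, below is a minimal sketch of a top-2 gated MoE feed-forward layer in JAX. All names (moe_layer, gate_w, w_in, w_out) are illustrative assumptions, and the sketch deliberately omits pieces of the paper's gating function such as per-expert capacity limits, the auxiliary load-balancing loss, and random routing of the second-choice expert; it also applies every expert densely to every token instead of dispatching only the routed tokens, which the real implementation does with AllToAll to keep cost sub-linear.

```python
# Minimal top-2 gated Mixture-of-Experts feed-forward layer (sketch only).
import jax
import jax.numpy as jnp

def moe_layer(x, gate_w, w_in, w_out):
    """x: [tokens, d_model]; gate_w: [d_model, E];
    w_in: [E, d_model, d_ff]; w_out: [E, d_ff, d_model]."""
    # Gating: softmax over experts, then pick the top-2 experts per token.
    gates = jax.nn.softmax(x @ gate_w, axis=-1)              # [tokens, E]
    top2_gates, top2_idx = jax.lax.top_k(gates, k=2)         # [tokens, 2]
    # Renormalize the two selected gate values so they sum to 1.
    top2_gates = top2_gates / jnp.sum(top2_gates, axis=-1, keepdims=True)

    # combine[t, e] is the gate weight if expert e was selected for token t,
    # and 0 otherwise (no capacity limit in this sketch).
    num_experts = gate_w.shape[-1]
    one_hot = jax.nn.one_hot(top2_idx, num_experts)          # [tokens, 2, E]
    combine = jnp.einsum('tk,tke->te', top2_gates, one_hot)  # [tokens, E]

    # Run every expert on every token, then weight outputs by the combine
    # matrix. (The paper instead dispatches only the routed tokens to each
    # expert's device via AllToAll.)
    hidden = jax.nn.relu(jnp.einsum('td,edf->etf', x, w_in))
    expert_out = jnp.einsum('etf,efd->etd', hidden, w_out)   # [E, tokens, d]
    return jnp.einsum('te,etd->td', combine, expert_out)

# Tiny usage example with random weights.
key = jax.random.PRNGKey(0)
d_model, d_ff, E, T = 8, 16, 4, 5
ks = jax.random.split(key, 4)
x = jax.random.normal(ks[0], (T, d_model))
out = moe_layer(x,
                jax.random.normal(ks[1], (d_model, E)),
                jax.random.normal(ks[2], (E, d_model, d_ff)),
                jax.random.normal(ks[3], (E, d_ff, d_model)))
print(out.shape)  # (5, 8)
```

In the paper's largest configuration, each MoE layer has one expert per device (2048 experts on 2048 cores), so only the dispatch and combine steps require cross-device communication.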
With GShard, model code is written as if it ran on a single device with enormous memory and computation capacity. Users annotate a handful of critical tensors with partitioning policies, for example splitting the expert weights of an MoE layer along the expert dimension while replicating most other tensors, and the compiler extension in XLA automatically partitions the whole computation graph based on these annotations and its own heuristics, inserting whatever cross-device communication is required (see the sketch below).

Partitioning uses an SPMD (Single Program Multiple Data) technique: rather than generating one program per device, the compiler emits a single partitioned program that every device runs, which keeps compilation time constant as the number of partitions grows. The SPMD partitioner is built on a small set of communication primitives, namely CollectivePermute, AllGather, AllReduce, and AllToAll, to move data between partitions; AllToAll in particular implements the dispatch and combine of tokens to and from the sharded experts.

Finally, the paper evaluates the 600-billion-parameter model on massive multilingual machine translation from 100 languages to English. The model processes 1 trillion tokens in under 4 days on 2048 TPU v3 accelerators and delivers translation quality far beyond prior art, demonstrating that GShard can scale giant models while preserving training efficiency and translation quality.
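As a concrete illustration of the annotate-then-compile workflow, the following is a minimal sketch using JAX's jax.sharding API, whose XLA SPMD partitioner descends from the approach introduced in this paper. The mesh axis name ('expert'), the function expert_ffn, and the tensor shapes are illustrative assumptions, not the paper's own code, which exposes lightweight annotation calls over tensors in its TensorFlow-based implementation.

```python
# Annotate a few tensors with partitioning policies and let the XLA SPMD
# partitioner shard the computation and insert collectives (sketch only).
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1-D device mesh over whatever accelerators are available
# (2048 TPU v3 cores in the paper; a single CPU device when run locally).
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=('expert',))

# Expert weights [num_experts, d_model, d_ff] are split along the expert
# dimension; token activations are replicated across partitions.
w_sharding = NamedSharding(mesh, P('expert', None, None))
x_sharding = NamedSharding(mesh, P())  # fully replicated

@jax.jit
def expert_ffn(x, w):
    # Re-assert the expert sharding inside the compiled computation; the
    # SPMD partitioner adds any communication needed to keep the single
    # program consistent across all devices.
    w = jax.lax.with_sharding_constraint(w, w_sharding)
    return jnp.einsum('td,edf->etf', x, w)   # [experts, tokens, d_ff]

num_experts = len(devices)                   # one expert per device here
tokens, d_model, d_ff = 4, 8, 16
x = jax.device_put(jnp.ones((tokens, d_model)), x_sharding)
w = jax.device_put(jnp.ones((num_experts, d_model, d_ff)), w_sharding)
print(expert_ffn(x, w).shape)                # (num_experts, 4, 16)
```

Because the same program runs on every partition and only the annotations change, adding devices does not grow the compiled program, which is how compilation time stays constant at the scale reported in the paper.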