21 Sep 2020 | Aravind Srinivas*, Michael Laskin*, Pieter Abbeel
CURL: Contrastive Unsupervised Representations for Reinforcement Learning
**Abstract:**
CURL is a novel framework that combines instance contrastive learning with reinforcement learning (RL) to extract high-level features from raw pixels. It outperforms prior pixel-based methods, both model-based and model-free, on complex tasks in the DeepMind Control Suite and Atari Games, achieving 1.9x and 1.2x performance gains at the 100K environment and interaction steps benchmarks, respectively. CURL is the first image-based algorithm to nearly match the sample-efficiency of state-based feature methods on the DeepMind Control Suite. The code is open-sourced and available at <https://www.github.com/MishaLaskin/curl>.
CURL addresses the sample inefficiency of deep RL algorithms by using contrastive learning to extract useful semantic representations from high-dimensional observations. It trains a visual representation encoder by ensuring that the embeddings of data-augmented versions of the same observation match using a contrastive loss. The query observations are treated as anchors, while the key observations contain positive and negative samples constructed from the minibatch. The RL policy and/or value function are built on top of the query encoder, which is jointly trained with the contrastive and reinforcement learning objectives. CURL is a generic framework that can be integrated into any RL algorithm that relies on learning representations from high-dimensional images.
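The anchor/positive/negative setup above can be sketched as an InfoNCE-style contrastive loss with a bilinear similarity, which is the form CURL uses. This is a minimal NumPy illustration, not the paper's implementation: the function name `info_nce_loss` and the toy batch are my own, and in CURL the keys come from a separate momentum-averaged encoder rather than a perturbed copy of the queries.

```python
import numpy as np

def info_nce_loss(queries, keys, W):
    """InfoNCE-style contrastive loss with bilinear similarity.

    queries[i] and keys[i] embed two augmented views of the same
    observation (the positive pair); keys[j] for j != i act as
    negatives drawn from the minibatch. W is a learned bilinear
    weight matrix. Shapes: queries, keys (B, D); W (D, D).
    """
    logits = queries @ W @ keys.T                 # (B, B) similarity matrix
    logits -= logits.max(axis=1, keepdims=True)   # numerical stability
    # Positives lie on the diagonal, so this is cross-entropy
    # with target label i for row i.
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return -np.mean(np.diag(log_probs))

# Toy example: keys are slightly perturbed queries, standing in for
# embeddings of two random crops of the same frame stack.
rng = np.random.default_rng(0)
B, D = 8, 16
q = rng.normal(size=(B, D))
k = q + 0.01 * rng.normal(size=(B, D))
loss = info_nce_loss(q, k, np.eye(D))   # small positive value
```

Because each row's positive is its own diagonal entry, minimizing this loss pulls matched augmented views together while pushing apart the other samples in the batch.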
CURL significantly improves sample efficiency over prior pixel-based methods by performing contrastive learning simultaneously with an off-policy RL algorithm. When coupled with the Soft Actor-Critic (SAC) algorithm, CURL achieves 1.9x median higher performance than Dreamer on DMControl environments benchmarked at 100k environment steps. On Atari games benchmarked at 100k interaction steps, CURL coupled with a data-efficient version of Rainbow DQN achieves 1.2x median higher performance over prior methods, improving upon Efficient Rainbow on 19 out of 26 Atari games and surpassing human efficiency on two games.
CURL is designed to add minimal overhead in terms of architecture and model learning. It operates with the same latent space and architecture typically used for model-free RL and seamlessly integrates with the training pipeline without introducing multiple additional hyperparameters. The key contributions of this work include:
- Presenting CURL, a simple framework that integrates contrastive learning with model-free RL with minimal changes to the architecture and training pipeline.
- Empirically showing that contrastive learning combined with model-free RL outperforms the leading prior pixel-based methods by 1.9x on DMControl and 1.2x on Atari.
- Demonstrating that a contrastive objective is the preferred self-supervised auxiliary task for achieving sample-efficiency compared to reconstruction-based methods.
- Enabling model-free methods to outperform state-of-the-art model-based methods in terms of data-efficiency.
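The data augmentation behind the matched query/key views is random cropping of stacked frames. As a rough sketch (the function name `random_crop` and the 100-to-84 pixel sizes are illustrative assumptions, not taken verbatim from the paper's code):

```python
import numpy as np

def random_crop(imgs, out_size, rng):
    """Randomly crop each image in a batch to out_size x out_size.

    Each batch element gets an independently sampled crop location,
    so two calls on the same batch yield two distinct augmented
    views of every observation. imgs: (B, C, H, W).
    """
    B, C, H, W = imgs.shape
    tops = rng.integers(0, H - out_size + 1, size=B)
    lefts = rng.integers(0, W - out_size + 1, size=B)
    out = np.empty((B, C, out_size, out_size), dtype=imgs.dtype)
    for i, (t, l) in enumerate(zip(tops, lefts)):
        out[i] = imgs[i, :, t:t + out_size, l:l + out_size]
    return out

rng = np.random.default_rng(0)
batch = rng.normal(size=(4, 9, 100, 100))  # e.g. 3 stacked RGB frames
crops = random_crop(batch, 84, rng)        # shape (4, 9, 84, 84)
```

Calling this twice on the same sampled batch produces the query and key views whose embeddings the contrastive loss then matches.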