25 Feb 2021 | David Krueger, Ethan Caballero, Joern-Henrik Jacobsen, Amy Zhang, Jonathan Binas, Dinghuai Zhang, Remi Le Priol, Aaron Courville
The article introduces Risk Extrapolation (REx), a method for out-of-distribution (OOD) generalization in machine learning. REx addresses distributional shift by assuming that the variation across training domains is representative of the variation that may occur at test time, while allowing test-time shifts to be more extreme.

The method is formulated as a robust optimization problem over a perturbation set of possible test domains, minimizing the worst-case risk across that set. Two variants are proposed: Minimax-REx (MM-REx), which takes the worst case over affine combinations of the training risks, and Variance-REx (V-REx), which penalizes the variance of the training risks. By encouraging equality of risks across domains, REx uncovers invariant relationships between inputs and targets, i.e. statistical relationships that hold across all domains in the perturbation set. This allows REx to outperform methods such as Invariant Risk Minimization (IRM) in scenarios that involve covariate shift and require invariant prediction.

Theoretical analysis shows that REx can recover the causal mechanism of the targets and provides robustness to changes in the input distribution. Experiments on tasks such as Colored MNIST and simulated robotics tasks show that REx significantly outperforms IRM in settings involving covariate shift, although IRM retains an advantage when some domains are intrinsically harder than others. The ability to handle both covariate and interventional shifts makes REx a powerful approach to OOD generalization.
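The two variants above can be written compactly once the per-domain training risks are available. The following is a minimal sketch, not the authors' implementation: it assumes the risks have already been computed (e.g. per-domain average losses), and the `beta` and `lam_min` hyperparameter names are illustrative. V-REx adds a variance penalty to the mean risk; MM-REx has a closed form as a weighted combination of the worst-case and summed risks, where a negative `lam_min` extrapolates beyond the convex hull of the training domains.

```python
import numpy as np

def vrex_objective(risks, beta=10.0):
    """V-REx: mean training risk plus beta times the variance of the
    per-domain risks. Minimizing the variance term pushes toward
    equal risks across domains."""
    risks = np.asarray(risks, dtype=float)
    return risks.mean() + beta * risks.var()

def mmrex_objective(risks, lam_min=-1.0):
    """MM-REx: worst case over affine combinations of per-domain
    risks whose weights sum to 1 and are bounded below by lam_min.
    Closed form: (1 - n*lam_min) * max_e R_e + lam_min * sum_e R_e."""
    risks = np.asarray(risks, dtype=float)
    n = len(risks)
    return (1 - n * lam_min) * risks.max() + lam_min * risks.sum()
```

Note that with `lam_min = 1/n` the MM-REx expression collapses to the mean risk (plain ERM), and as `lam_min` becomes more negative the objective penalizes inequality between domain risks more strongly, which is what drives the extrapolation behavior.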