7 Feb 2024 | Allan Zhou, Chelsea Finn, James Harrison
This paper introduces a method for constructing *universal neural functionals* (UNFs), which are permutation-equivariant models that can process weight-space features from any neural network architecture. The authors address the challenge of handling complex weight spaces, such as those with recurrence or residual connections, by automatically constructing equivariant models. UNFs are designed to be flexible and can be applied to various weight spaces, including those of recurrent neural networks (RNNs) and Transformers. The paper demonstrates the effectiveness of UNFs in improving the performance of learned optimizers on tasks involving small image classifiers and language models. The authors also provide an open-source implementation of their algorithm, which is compatible with most JAX neural network libraries. The results suggest that learned optimizers can benefit from considering the symmetry structure of the weight space they optimize.
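To make the "symmetry structure of the weight space" concrete, here is a minimal JAX sketch of the hidden-neuron permutation symmetry that equivariant weight-space models are built to respect. This is an illustration on a toy two-layer MLP, not the paper's UNF construction or its released code: permuting the hidden units (columns of `W1`, entries of `b1`, rows of `W2`) leaves the network's function unchanged, so a model that consumes these weights (e.g. a learned optimizer) should commute with such permutations.

```python
import jax
import jax.numpy as jnp

# Toy 2-layer MLP; all names here are illustrative, not from the paper's code.
def mlp(params, x):
    W1, b1, W2, b2 = params
    h = jax.nn.relu(x @ W1 + b1)
    return h @ W2 + b2

key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)
W1 = jax.random.normal(k1, (4, 8))   # input dim 4, hidden dim 8
b1 = jnp.zeros(8)
W2 = jax.random.normal(k2, (8, 3))   # output dim 3
b2 = jnp.zeros(3)
x = jax.random.normal(k3, (5, 4))    # batch of 5 inputs

# Permute the hidden units: columns of W1 and entries of b1 move together
# with the corresponding rows of W2.
perm = jax.random.permutation(jax.random.PRNGKey(1), 8)
params = (W1, b1, W2, b2)
params_perm = (W1[:, perm], b1[perm], W2[perm, :], b2)

# The network computes the same function under this permutation, which is
# exactly the symmetry an equivariant weight-space model should respect.
print(jnp.allclose(mlp(params, x), mlp(params_perm, x), atol=1e-5))  # True
```

UNFs generalize this idea beyond feedforward MLPs: the paper's algorithm automatically derives the corresponding permutation symmetries for more complex weight spaces, such as those with recurrence or weight sharing, and constructs layers that are equivariant to them.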