7 Feb 2024 | Allan Zhou, Chelsea Finn, James Harrison
This paper introduces Universal Neural Functionals (UNFs), a method for constructing permutation-equivariant neural functionals that operate on arbitrary weight spaces. Unlike previous methods, which were limited to the weight spaces of simple feedforward networks, UNFs can be applied to any weight space, including those of more complex architectures such as recurrent networks and Transformers. The core algorithm automatically constructs a permutation-equivariant layer for any collection of tensors whose dimensions permute according to a shared set of permutations; stacking these layers yields the deep permutation-equivariant models the authors call universal neural functionals (UNFs).
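To make the symmetry concrete, the sketch below (written for this summary, not taken from the paper's library) shows the kind of shared permutation a UNF layer must respect: permuting the hidden units of a two-layer MLP permutes the rows of the first weight matrix and the columns of the second, yet the network computes exactly the same function.

```python
# Minimal sketch of the weight-space symmetry UNFs are built around:
# one permutation of the hidden units acts on a row axis of W1 and a
# column axis of W2 simultaneously, leaving the MLP's function unchanged.
import jax
import jax.numpy as jnp

def mlp(params, x):
    W1, W2 = params
    return jnp.tanh(x @ W1.T) @ W2.T

k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
W1 = jax.random.normal(k1, (16, 8))   # (hidden, input)
W2 = jax.random.normal(k2, (4, 16))   # (output, hidden)
x = jax.random.normal(k3, (8,))

perm = jax.random.permutation(k4, 16)      # permutation of the hidden dimension
permuted = (W1[perm, :], W2[:, perm])      # the same permutation hits both tensors

print(jnp.allclose(mlp((W1, W2), x), mlp(permuted, x)))  # True
```

A UNF layer is constructed so that applying such a permutation to its inputs permutes its outputs in the same way, and the construction extends to tensors whose dimensions are acted on by several permutations at once, which is what recurrent and Transformer weight spaces require.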
The paper demonstrates that UNFs can be substituted into existing learned optimizer designs, achieving promising improvements over prior methods when optimizing small image classifiers and language models. The results suggest that learned optimizers can benefit from considering the symmetry structure of the weight space they optimize. The authors open-source their library for constructing UNFs at https://github.com/AllanYangZhou/universal_neural_functional.
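As a rough illustration of where a UNF slots in, the hypothetical sketch below (the function names are not the paper's API) shows a learned optimizer whose per-step update is produced by a permutation-equivariant weight-space model rather than an elementwise rule.

```python
# Hypothetical sketch of a learned-optimizer inner step. `weight_space_model`
# stands in for a UNF: it maps a collection of tensors that share the task
# network's symmetries (here gradients and momentum) to update tensors of the
# same shapes; equivariance guarantees that permuting hidden units of the
# inputs permutes the proposed updates identically.
import jax

def inner_step(meta_params, weight_space_model, weights, grads, momentum,
               beta=0.9, lr=1e-3):
    momentum = jax.tree_util.tree_map(
        lambda m, g: beta * m + (1.0 - beta) * g, momentum, grads)
    updates = weight_space_model(meta_params, grads, momentum)
    weights = jax.tree_util.tree_map(
        lambda w, u: w - lr * u, weights, updates)
    return weights, momentum
```

Meta-training then optimizes `meta_params` through many such inner steps; the paper's small-scale experiments suggest that making the update network a UNF, rather than a per-parameter rule, is what accounts for the reported improvements.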
The experiments show that UNFs outperform prior methods at predicting the generalization of recurrent sequence-to-sequence models and at training learned optimizers across several architectures and datasets, demonstrating that UNFs can effectively process the weights and gradients of convolutional image classifiers, recurrent sequence-to-sequence models, and Transformer language models.
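For the generalization-prediction task the model must output a single permutation-invariant scalar rather than per-weight features. Assuming UNFs follow the usual recipe of an equivariant backbone followed by invariant pooling (a common pattern in weight-space learning; the names below are illustrative, not the paper's code), the readout can look like this:

```python
# Illustrative sketch of an invariant readout on top of an equivariant
# weight-space backbone. Each output tensor has shape
# (*permutable_dims, channels); averaging over the permutable dims makes
# the concatenated features invariant to weight-space permutations.
import jax.numpy as jnp

def invariant_readout(equivariant_outputs, head_params):
    pooled = [t.mean(axis=tuple(range(t.ndim - 1))) for t in equivariant_outputs]
    features = jnp.concatenate(pooled)   # permutation-invariant feature vector
    W, b = head_params
    return features @ W + b              # e.g. a predicted test accuracy
```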
The paper also discusses the limitations of UNFs, including the challenge of applying them to heterogeneous weight-space inputs and the question of whether the constructed models remain computationally tractable for more complex architectures. The authors conclude that UNFs have the potential to improve the scalability and applicability of neural functionals to weight-space tasks.