28 Feb 2019 | Andrew Hard, Kanishka Rao, Rajiv Mathews, Swaroop Ramaswamy, Françoise Beaufays, Sean Augenstein, Hubert Eichner, Chloé Kiddon, Daniel Ramage
This paper studies federated learning for next-word prediction in a mobile keyboard: training a language model to predict the next word in a virtual keyboard without sharing user data with servers. The study compares conventional server-based training with stochastic gradient descent (SGD) against federated training with the FederatedAveraging algorithm. The federated approach achieves better prediction recall, demonstrating the feasibility of training language models on client devices without exporting sensitive data.
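To make the FederatedAveraging loop concrete, here is a minimal, self-contained sketch in Python. The toy linear model and the helper names (`client_update`, `fedavg_round`) are illustrative assumptions, not the paper's TensorFlow implementation; what the sketch preserves is the algorithm's structure: local SGD on each client, then example-count-weighted averaging on the server.

```python
import numpy as np

def client_update(w, local_data, lr=0.05, epochs=1):
    """One client's contribution: a few epochs of local SGD on private
    data. A toy linear model with squared loss stands in for the CIFG."""
    w = w.copy()
    for _ in range(epochs):
        for x, y in local_data:
            grad = 2.0 * (w @ x - y) * x   # dL/dw for L = (w.x - y)^2
            w = w - lr * grad
    return w, len(local_data)

def fedavg_round(global_w, client_datasets):
    """Server side of one FederatedAveraging round: broadcast the global
    model, collect trained client models, and average them weighted by
    the number of examples each client trained on."""
    results = [client_update(global_w, data) for data in client_datasets]
    total = sum(n for _, n in results)
    return sum((n / total) * w for w, n in results)

def make_client(rng, true_w, n=20):
    """Generate one client's private (x, y) pairs from a shared truth."""
    xs = rng.normal(size=(n, len(true_w)))
    return [(x, float(true_w @ x)) for x in xs]

# Toy usage: three clients; raw data never leaves client_update, only
# the locally trained weights are returned to the server.
rng = np.random.default_rng(0)
true_w = np.array([1.0, -2.0, 0.5, 3.0])
clients = [make_client(rng, true_w) for _ in range(3)]
global_w = np.zeros(4)
for _ in range(25):
    global_w = fedavg_round(global_w, clients)
print(np.round(global_w, 2))  # approaches true_w
```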
The next-word model is a recurrent neural network (RNN) variant called the Coupled Input-Forget Gate (CIFG) LSTM, trained on user-generated text. Under federated learning, clients train the model locally and share only model updates with a server. The model is sized for mobile devices, with a 10,000-word vocabulary and roughly 1.4 million parameters; it is trained with TensorFlow and deployed for on-device inference with TensorFlow Lite.
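The coupling in the CIFG ties the forget gate to the input gate (f = 1 - i), removing one gate's worth of parameters relative to a standard LSTM, which is what makes it attractive at mobile-scale budgets. Below is a minimal NumPy sketch of a single CIFG step; the class name, weight shapes, initialization, and toy dimensions are ours, not the paper's.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

class CIFGCell:
    """Coupled Input-Forget Gate LSTM cell: the forget gate is tied to
    the input gate (f = 1 - i), so no separate forget-gate weights are
    needed. Shapes and initialization here are illustrative only."""

    def __init__(self, input_dim, hidden_dim, seed=0):
        rng = np.random.default_rng(seed)
        def block():
            W = rng.normal(scale=0.1, size=(hidden_dim, input_dim))
            U = rng.normal(scale=0.1, size=(hidden_dim, hidden_dim))
            b = np.zeros(hidden_dim)
            return W, U, b
        self.input_gate = block()   # i
        self.output_gate = block()  # o
        self.candidate = block()    # c~
        # Note: no forget-gate block; it is derived from the input gate.

    @staticmethod
    def _affine(params, x, h):
        W, U, b = params
        return W @ x + U @ h + b

    def step(self, x, h_prev, c_prev):
        i = sigmoid(self._affine(self.input_gate, x, h_prev))
        f = 1.0 - i                            # coupled forget gate
        o = sigmoid(self._affine(self.output_gate, x, h_prev))
        c_tilde = np.tanh(self._affine(self.candidate, x, h_prev))
        c = f * c_prev + i * c_tilde           # new cell state
        h = o * np.tanh(c)                     # new hidden output
        return h, c

# Toy usage: one step over a zero-embedded token (dimensions arbitrary).
cell = CIFGCell(input_dim=64, hidden_dim=128)
h, c = cell.step(np.zeros(64), np.zeros(128), np.zeros(128))
```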
In each round of federated learning, participating clients train on locally stored data and upload model updates to a server, which aggregates them into a new global model. The study evaluates the federated model against a server-trained model and a baseline n-gram model, on both server-hosted logs and client-owned data. The federated model achieves better recall and prediction accuracy: relative to the server-trained model, top-1 recall improves by 5% and prediction impression recall by 1%.
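Recall here measures how often the word the user actually typed appears among the model's top candidates. A minimal sketch of such a metric, assuming a simple list-of-ranked-candidates representation (the function name and data layout are ours, not the paper's evaluation code):

```python
def top_k_recall(predictions, targets, k=1):
    """Fraction of next-word events where the typed word appears in the
    model's top-k candidate list (top-1 recall when k=1). `predictions`
    holds one ranked candidate list per event; `targets` holds the words
    actually typed."""
    hits = sum(target in ranked[:k]
               for ranked, target in zip(predictions, targets))
    return hits / len(targets)

# Toy usage: three next-word events.
ranked = [["the", "a", "to"], ["cat", "car", "can"], ["runs", "ran", "is"]]
typed = ["the", "car", "walked"]
print(top_k_recall(ranked, typed, k=1))  # 1/3: only "the" is ranked first
print(top_k_recall(ranked, typed, k=3))  # 2/3: "car" appears in the top 3
```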
The study also evaluates the model in live production experiments with a subset of Gboard users. The federated model generates predictions that are 10% more likely to be clicked than those of the n-gram baseline. Together, the results show that federated learning is a preferable alternative to server-based training of neural language models, offering security and privacy benefits by keeping training data on a population of distributed devices.