Federated Learning

Overview

This is a system for training a text-generation recurrent neural network (RNN) across multiple client machines.

The system consists of a server and multiple clients that work together to train a language model. The server distributes training data and model weights to the clients; each client evaluates the model and sends back rewards (loss values), which are used to update the model’s parameters.

Key Features

How It Works

  1. The server initializes model weights and starts listening for client connections
  2. Clients connect to the server and receive initial weights and configuration
  3. For each training step (sketched in code after this list):
    • The server distributes tokenized text data to clients (optionally, only every N steps)
    • Clients compute loss values using the current model weights with added noise
    • Clients send loss values back to the server
    • Server normalizes the rewards and sends them back to clients
    • Clients update their local copy of the model weights based on these normalized rewards
    • Periodically, the server requests updated weights from a client to save checkpoints
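
The sketch below summarizes one training step as described above, assuming an evolution-strategies-style scheme in which each client perturbs the shared weights, reports a loss, and applies the reward-weighted update locally. The function and constant names (evaluate_loss, NOISE_STD, LR) are placeholders rather than the project's actual API, and the sketch assumes every client can reconstruct all perturbations (for example, from shared random seeds).

```python
import numpy as np

NOISE_STD = 0.02  # standard deviation of the weight perturbation (assumed value)
LR = 0.01         # learning rate for the reward-weighted update (assumed value)

def client_step(weights, batch, seed, evaluate_loss):
    """Perturb the shared weights with seeded noise and report the resulting loss."""
    rng = np.random.default_rng(seed)
    noise = [rng.standard_normal(w.shape) * NOISE_STD for w in weights]
    perturbed = [w + n for w, n in zip(weights, noise)]
    return evaluate_loss(perturbed, batch), noise

def server_normalize(losses):
    """Turn raw losses into zero-mean, unit-variance rewards (lower loss -> higher reward)."""
    rewards = -np.asarray(losses, dtype=np.float64)
    return (rewards - rewards.mean()) / (rewards.std() + 1e-8)

def client_update(weights, all_noise, normalized_rewards):
    """Move the local weight copy toward perturbations that earned high rewards."""
    n_clients = len(all_noise)
    for i in range(len(weights)):
        step = sum(r * noise[i] for r, noise in zip(normalized_rewards, all_noise))
        weights[i] = weights[i] + LR * step / n_clients
    return weights
```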

Technical Details

Model Architecture

The system implements a recurrent neural network (RNN) language model with the following components:
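
As an illustration, a small RNN language model along these lines might look like the following PyTorch sketch; the class name, layer choices, and sizes are assumptions rather than the project's exact architecture.

```python
import torch
import torch.nn as nn

class CharRNNLM(nn.Module):
    """Illustrative RNN language model: embedding -> GRU -> projection to vocabulary."""
    def __init__(self, vocab_size, embed_dim=64, hidden_dim=256):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.rnn = nn.GRU(embed_dim, hidden_dim, batch_first=True)
        self.head = nn.Linear(hidden_dim, vocab_size)

    def forward(self, tokens, hidden=None):
        x = self.embed(tokens)             # (batch, seq, embed_dim)
        out, hidden = self.rnn(x, hidden)  # (batch, seq, hidden_dim)
        return self.head(out), hidden      # logits over the vocabulary
```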

Optimization

The model is trained using:
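
Based on the procedure in How It Works, the update is a reward-weighted step in the direction of the sampled noise (an evolution-strategies-style scheme). A sketch of that update, where the symbols are assumed names rather than ones taken from the code:

    θ ← θ + (α / N) · Σ_i r̃_i · δ_i

Here δ_i is the noise client i added to the weights, r̃_i is that client's normalized reward, α is a learning rate, and N is the number of clients.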

Communication Protocol

The system uses a custom communication protocol with:
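
As a hypothetical illustration of what such a custom protocol could look like, the sketch below frames each message as a 4-byte length prefix followed by a JSON body; the framing, encoding, and function names are assumptions, not the project's actual wire format.

```python
import json
import socket
import struct

def _recv_exact(sock: socket.socket, n: int) -> bytes:
    """Read exactly n bytes from the socket, or raise if the peer disconnects."""
    data = b""
    while len(data) < n:
        chunk = sock.recv(n - len(data))
        if not chunk:
            raise ConnectionError("socket closed mid-message")
        data += chunk
    return data

def send_message(sock: socket.socket, payload: dict) -> None:
    """Send one message: a 4-byte big-endian length prefix, then a UTF-8 JSON body."""
    body = json.dumps(payload).encode("utf-8")
    sock.sendall(struct.pack(">I", len(body)) + body)

def recv_message(sock: socket.socket) -> dict:
    """Receive one length-prefixed JSON message."""
    (length,) = struct.unpack(">I", _recv_exact(sock, 4))
    return json.loads(_recv_exact(sock, length).decode("utf-8"))
```

For example, the server might send a message like {"type": "weights", "data": [...]} and wait for a {"type": "loss", "value": ...} reply; these message shapes are likewise illustrative only.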

More information on the math can be found in the paper I wrote for my Calc 4 final, "Federated Learning Project".

Future Work