Some notes from recent experience building model training infrastructure at scale.
Xander Dunn, 31 January 2024
These thoughts run across the stack, from things only infrastructure engineers need to worry about, to things the experimenters need to worry about. This isn't intended to be a comprehensive playbook. These are some of the things we've thought about and implemented recently in our training infrastructure. I'm sure there are better ways to do some of these things, so don't hesitate to send me your thoughts!
Consider Your Scale
If you've only got 8 machines, then the below will not be very useful to you. You'll want the logging and that's about it. If you've got 10,000 machines, then the below is not enough. For example, Google's training of Gemini Ultra across multiple clusters and multiple data centers would require additional tricks. DiLoCo and DiPaCo may be some of these tricks. This is a scale I have not yet worked on. We're in the middle, with machines measured in the hundreds and our largest single cluster at 8,000 devices.
Cluster Management
Periodic hardware failure testing. Loud hardware failures that prevent progress on that machine are actually the good kind of failure. The more pernicious failures are the silent ones that produce bad values with no errors. We proactively look for these failed devices by periodically testing our hardware. One approach is to run the same computation on multiple GPUs and check that the outputs are the same across GPUs. We have a Docker "preflight" image that we run on machines that exercises the full computation, memory, and communication gamut of a machine's devices. This has been very effective at exposing hardware issues. This check needs to be periodically run even if there are no signs of hardware degradation.
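As a rough illustration of the cross-device consistency idea, here's a minimal sketch assuming PyTorch and CUDA GPUs; a real preflight image exercises far more than this, and the matrix size and tolerance here are arbitrary.

```python
# Minimal cross-device consistency check (sketch). Assumes PyTorch + CUDA GPUs.
import torch

def consistency_check(size: int = 4096, tol: float = 0.0) -> list[int]:
    """Run the same seeded matmul on every visible GPU and flag devices whose
    output differs from GPU 0 by more than tol."""
    torch.manual_seed(0)
    a = torch.randn(size, size)
    b = torch.randn(size, size)
    outputs = []
    for i in range(torch.cuda.device_count()):
        device = torch.device(f"cuda:{i}")
        # Identical inputs on every device; identical hardware running identical
        # kernels should produce matching results.
        outputs.append((a.to(device) @ b.to(device)).cpu())
    return [i for i, out in enumerate(outputs[1:], start=1)
            if (out - outputs[0]).abs().max().item() > tol]

if __name__ == "__main__":
    suspects = consistency_check()
    print("suspect devices:", suspects or "none")
```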
Hardware failure fault tolerance. You don't want to be in a situation where a single hardware failure causes the entire training run to grind to a halt for any significant amount of time. There are many ways to deal with this issue. The approach we took as a first step was fault tolerance across data parallel replicas of the model. The goal is for replicas to be hot-swappable, so that when a particular piece of hardware fails, that machine can be quickly swapped out and an all-reduce from a working replica brings the parameters in sync. This also gives us the ability to hot swap instances to proactively search for failed hardware. We can periodically take out machines and run preflight on them. This approach only works for data parallel > 1. The key metric here is time to (re)start the training run. When the run gets killed for any reason, how long does it take to reach the first loss output?
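Concretely, the re-sync step can be as small as the sketch below, assuming PyTorch with torch.distributed already initialized across the data parallel group; broadcasting from a known-healthy rank is the simplest variant of the sync described above.

```python
# Minimal parameter re-sync for a hot-swapped replica (sketch). Assumes PyTorch
# with torch.distributed already initialized across data parallel replicas.
import torch
import torch.distributed as dist

def resync_replica(model: torch.nn.Module, healthy_src_rank: int = 0) -> None:
    """Overwrite this process's parameters and buffers with a healthy replica's
    copy so the freshly swapped-in machine can rejoin the run mid-training."""
    with torch.no_grad():
        for tensor in list(model.parameters()) + list(model.buffers()):
            dist.broadcast(tensor, src=healthy_src_rank)
```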
Warm pool. The above leads to a warm pool of machines that are ready to be swapped in when a particular machine experiences failure. This gets a replica back into the training run as quickly as possible so that all the other machines in the replica aren't sitting around doing nothing. Getting a machine into the run as quickly as possible requires efficient access to the compiled models that need to be run, as well as the data. This pool need not be very large. Generally when hardware failures occur, there are just 1 or 2 simultaneously. This approach requires that there are machines on the same spine that we're not using for training.
Getting data on every machine. At some point, your datasets will become too large to copy onto every machine. There are many potential solutions here. One is to have a single copy of the dataset on a network drive that is mounted to all machines in the experiment. On AWS these would be EFS or FSx drives. The performance of the drive need only exceed the speed at which the models consume samples. Each machine should be preemptively reading its samples from the drive and queuing them for model consumption in the background. If there's sufficient device memory for it, the next iteration's samples can be copied to the device before they're trained on.
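A minimal sketch of that background prefetch, assuming PyTorch, CUDA devices, and batches that are plain tensors: DataLoader workers read ahead from the shared mount while a wrapper copies the next batch to the device on a side stream, overlapping the copy with training on the current batch.

```python
# Background prefetch to device (sketch). Assumes PyTorch, CUDA, and that each
# batch is a single tensor; a dataset on a shared mount (e.g. EFS/FSx) is read
# ahead by DataLoader workers.
import torch
from torch.utils.data import DataLoader

class DevicePrefetcher:
    def __init__(self, loader: DataLoader, device: torch.device):
        self.loader, self.device = loader, device
        self.stream = torch.cuda.Stream(device)

    def __iter__(self):
        it = iter(self.loader)
        nxt = self._preload(it)
        while nxt is not None:
            # Make the compute stream wait for the current batch's copy only.
            torch.cuda.current_stream(self.device).wait_stream(self.stream)
            batch, nxt = nxt, self._preload(it)
            # Tell the caching allocator this tensor is now used on the compute
            # stream so its memory isn't reused prematurely.
            batch.record_stream(torch.cuda.current_stream(self.device))
            yield batch

    def _preload(self, it):
        try:
            batch = next(it)
        except StopIteration:
            return None
        with torch.cuda.stream(self.stream):
            return batch.to(self.device, non_blocking=True)

# Usage: loader = DataLoader(dataset, batch_size=..., num_workers=8, pin_memory=True)
#        for batch in DevicePrefetcher(loader, torch.device("cuda:0")): ...
```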
Fault tolerant cluster manager. Most of the above points will need to be implemented in some scalable, fault tolerant manager. We implemented ours in custom Rust, which entailed a lot of hand-crafted fault tolerance considerations. Were I to do this from scratch I'd use Kubernetes, which has a lot of APIs that we effectively had to re-implement. A good reference on Kubernetes training infra is OpenAI's article on the topic. Of course, the infrastructure shouldn't be manually provisioned. Use infrastructure as code such as Terraform or Pulumi. These are the basics.
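For a flavor of what infrastructure as code looks like here, a minimal Pulumi sketch in Python; the AMI ID, instance type, and node count are placeholders, not what we actually run.

```python
# Minimal infrastructure-as-code sketch using Pulumi's Python SDK.
# The AMI ID, instance type, and node count below are placeholders.
import pulumi
import pulumi_aws as aws

NUM_TRAINING_NODES = 4  # hypothetical cluster size

nodes = [
    aws.ec2.Instance(
        f"training-node-{i}",
        ami="ami-0123456789abcdef0",   # placeholder accelerator AMI
        instance_type="p4d.24xlarge",  # placeholder instance type
        tags={"role": "training"},
    )
    for i in range(NUM_TRAINING_NODES)
]

pulumi.export("node_ids", [node.id for node in nodes])
```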
Monitoring
Centralized log aggregation. It's an obvious one and an important one. The ability to rapidly filter logs via complex expressions will become essential. When a training run goes wrong, how will you wade through tens of millions of logs to understand what went wrong? We use DataDog and our monthly bill across 8,000 devices is ~$1,500. Make sure to impose a limit on daily spend. Consider a bug that prints a log every 10ms, running on every device in an 8,000-device cluster. You will rack up hundreds of millions of logs and tens of thousands of dollars in DataDog bills in very short order.
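Beyond the provider-side spend limit, a client-side rate limit is a cheap safety net against exactly that runaway-log bug. A minimal sketch using only Python's standard library, with arbitrary thresholds:

```python
# Client-side log rate limiting (sketch). Thresholds are arbitrary.
import logging
import time
from collections import defaultdict

class RateLimitFilter(logging.Filter):
    """Drop a (logger, message) pair after max_per_window hits in a window."""
    def __init__(self, max_per_window: int = 100, window_secs: float = 60.0):
        super().__init__()
        self.max_per_window = max_per_window
        self.window_secs = window_secs
        self.counts: dict[tuple, list] = defaultdict(lambda: [0.0, 0])

    def filter(self, record: logging.LogRecord) -> bool:
        key = (record.name, record.msg)
        window_start, count = self.counts[key]
        now = time.monotonic()
        if now - window_start > self.window_secs:
            self.counts[key] = [now, 1]
            return True
        if count < self.max_per_window:
            self.counts[key][1] = count + 1
            return True
        return False  # suppressed: a runaway 10ms log loop stops here

# Attach to whatever handler ships logs off the machine, e.g.:
# handler.addFilter(RateLimitFilter())
```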
Centralized hardware monitoring. In the case of TRN this is neuron-monitor and in the case of GPU this is DCGM. The data should be piped into Prometheus or similar for visualization. These metrics can alert you to issues. One way of interpreting how well the hardware is utilized is looking at temperature and energy consumption. These tools can also output info on the frequency of correctable memory errors. The device utilization shown in nvidia-smi isn't particularly useful on its own. I frequently see people on X posting images of nvidia-smi 100% utilization, which looks like a widespread misunderstanding. To disabuse yourself of the notion that nvidia-smi utilization is useful, create an intentional comm deadlock across two GPUs and take a look at the nvidia-smi utilization. It will show perfect 100% utilization, but of course it isn't actually doing anything. nvidia-smi is showing what % of the past N polls some op was running, so 100% means some op has been running on every poll, but this says nothing about whether progress is being made or what % of max flops are being utilized. A much more interesting metric is tensor core utilization, which DCGM produces.
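For example, NVIDIA's dcgm-exporter exposes the tensor core activity counter over a Prometheus endpoint (typically :9400/metrics). Normally Prometheus scrapes this for you; the sketch below just shows which metric to look at instead of the nvidia-smi utilization column.

```python
# Pull tensor core activity from dcgm-exporter (sketch). Assumes the exporter
# is running on the node with its default :9400/metrics endpoint.
import urllib.request

def tensor_core_activity(host: str = "localhost", port: int = 9400) -> dict[str, float]:
    """Return per-GPU tensor pipe activity (0.0-1.0) as reported by DCGM."""
    text = urllib.request.urlopen(f"http://{host}:{port}/metrics").read().decode()
    activity = {}
    for line in text.splitlines():
        # e.g. DCGM_FI_PROF_PIPE_TENSOR_ACTIVE{gpu="0",...} 0.43
        if line.startswith("DCGM_FI_PROF_PIPE_TENSOR_ACTIVE"):
            labels, value = line.rsplit(" ", 1)
            activity[labels] = float(value)
    return activity
```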
Utilization metrics. Hardware monitoring metrics lead to the ability to calculate utilization metrics such as MFU (see PaLM paper, Section 4.1) and "goodput." MFU is a good starting point, but that alone is not the end goal. You want to know what % of your flops is producing valuable output. It's possible to keep every device at high utilization while burning it all on useless computation: divergent loss spikes, bad data, etc.
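A back-of-the-envelope MFU calculation in the spirit of the PaLM paper's Section 4.1, using the common ~6 flops per parameter per token approximation and ignoring the attention term; the numbers in the example are made up.

```python
# Rough MFU estimate (sketch), ignoring the attention flops term.
def mfu(params: float, tokens_per_sec: float, num_devices: int,
        peak_flops_per_device: float) -> float:
    """Fraction of theoretical peak flops spent on forward + backward."""
    observed_flops = 6.0 * params * tokens_per_sec     # ~6N flops per token
    peak_flops = num_devices * peak_flops_per_device
    return observed_flops / peak_flops

# Example with made-up numbers: a 70e9-parameter model at 400,000 tokens/sec on
# 1,024 devices with 312e12 peak flops each:
# mfu(70e9, 4.0e5, 1024, 312e12) ≈ 0.53
```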
Experiment Layer
Job scheduling. Experimenters should be abstracted away from who is using what machines, and when. There should be some kind of experiment platform or job submission API that enables defining a run along with its resource requirements and then running it when those resources are available. Recording and presenting these experiment runs in some way is also vital for the experimenters to remember which run was doing what in the future.
Model compilation pipeline. You don't want to have to compile the model on every machine every time it's run. Machines are essentially ephemeral and compilation can take a long time, so it helps for a machine to know its place in the training run and quickly grab the pre-compiled models. Making the graphs SPMD (single program, multiple data) is an important property to achieve here. As an example, consider send<>recv comm ops in your graphs. Each op will have different replica_ids defined on each machine, device, and shard. This produces a large number of unique graphs that must be compiled and distributed. This can be converted to an SPMD graph via the collective permute op, which reduces the number of unique graphs, the compilation work, and the model downloads onto new machines.
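The same idea, illustrated with JAX's ppermute (a collective permute): because the (source, destination) pairs are an argument to a single op rather than baked into per-replica send/recv ops, every device compiles and runs one identical program. This is just an illustration of the concept, not a description of our actual stack.

```python
# Collective permute as an SPMD op (sketch). Assumes JAX with >1 local device.
import jax
import jax.numpy as jnp

n = jax.local_device_count()
# Rotate each replica's activations one replica to the right. The permutation
# is data to the op, so every device compiles and runs the same graph.
perm = [(i, (i + 1) % n) for i in range(n)]

def shift(x):
    return jax.lax.ppermute(x, axis_name="replicas", perm=perm)

shift_pmap = jax.pmap(shift, axis_name="replicas")
print(shift_pmap(jnp.arange(n, dtype=jnp.float32)))
```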
Automated loss spike recovery. Divergences will happen. Ideally these will be automatically detected and recovered from checkpoint. For the worst offenders you'll need to skip the data, test/change out the hardware, and lower the learning rate.
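A minimal sketch of the detection half, with an arbitrary running-average heuristic; restore_checkpoint and skip_batches are hypothetical stand-ins for whatever your training loop actually provides.

```python
# Loss spike detection (sketch). Window size and spike factor are arbitrary.
from collections import deque

class LossSpikeGuard:
    def __init__(self, window: int = 100, spike_factor: float = 3.0):
        self.recent = deque(maxlen=window)
        self.spike_factor = spike_factor

    def observe(self, loss: float) -> bool:
        """Return True if this loss looks like a spike relative to the window."""
        is_spike = (len(self.recent) == self.recent.maxlen and
                    loss > self.spike_factor * (sum(self.recent) / len(self.recent)))
        if not is_spike:
            self.recent.append(loss)
        return is_spike

# In the training loop (restore_checkpoint / skip_batches are hypothetical):
# if guard.observe(loss.item()):
#     restore_checkpoint(latest_good)
#     skip_batches(data_iter, n=num_suspect_batches)
```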
Good error handling. There are some errors that should be show-stopping. But think very carefully about what those errors are. Most errors should not be showstoppers. For example, suppose the request to write your per-iteration metadata to the SQL database fails for some reason. That probably shouldn't be a showstopper. The error should be handled first by retrying the request, and ultimately if it never succeeds it shouldn't prevent the training from making progress. Hardware is too expensive to not make use of it when non-critical components are down.
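A minimal sketch of that retry-then-move-on policy; write_metadata is a hypothetical stand-in for the SQL write.

```python
# Retry non-critical work with backoff; never let it stop training (sketch).
import logging
import time

def try_non_critical(fn, retries: int = 5, base_delay: float = 0.5):
    """Call fn(), retrying with exponential backoff. On final failure, log the
    error and return None instead of raising, so training keeps making progress."""
    for attempt in range(retries):
        try:
            return fn()
        except Exception as exc:
            if attempt == retries - 1:
                logging.error("non-critical operation failed after %d retries: %s",
                              retries, exc)
                return None
            time.sleep(base_delay * 2 ** attempt)

# Usage, e.g. writing per-iteration metadata (write_metadata is hypothetical):
# try_non_critical(lambda: write_metadata(step, metrics))
```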
Performance visualization. You'll want to be able to benchmark models and visualize the time each component of the model is taking. This will greatly help focus performance optimization efforts.
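One low-effort way to get this, assuming PyTorch: the built-in profiler's table gives a quick per-op breakdown, and the exported Chrome trace can be opened in Perfetto or chrome://tracing.

```python
# Per-component timing with the PyTorch profiler (sketch).
import torch
from torch.profiler import profile, ProfilerActivity

def benchmark(model: torch.nn.Module, batch: torch.Tensor, steps: int = 10) -> None:
    activities = [ProfilerActivity.CPU]
    if torch.cuda.is_available():
        activities.append(ProfilerActivity.CUDA)
    with profile(activities=activities, record_shapes=True) as prof:
        for _ in range(steps):
            model(batch)
    # Quick per-op summary in the terminal, plus a trace for a visual timeline.
    print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=20))
    prof.export_chrome_trace("trace.json")
```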
Iteration rate is paramount. Most experiments are going to fail, either because they crash or because the model doesn't perform well. Therefore, it's extremely valuable to shorten that feedback loop as much as possible. What is the time from submitting the first experiment to the first loss? This is so important it should be one of the North Star metrics for AI training infrastructure. And of course, problems like crashes should be worked out on small portions of the cluster before scaling to consume most of the cluster.