Distributed learning?
Recently, @PrimeIntellect announced their 10B-parameter distributed training run (twitter.com/PrimeIntellect/s…). What is it exactly?
Going back to the origins, Federated Learning (FL, arxiv.org/abs/1602.05629) aims to train a model across a fleet of phones. To cope with the low bandwidth between them, the gist is that each phone performs many training steps (forward/backward and update) independently of the others. After a hundred or so steps, each phone's replica of the model computes its trajectory (aka an outer gradient): simply the delta in parameter space between where the model started and where it is now.
Then, an all-reduce (so no need for a centralized server!) averages all the replicas' outer gradients. The peak communication is thus as high as in data-parallel training (pytorch.org/docs/stable/gene…), but because it happens rarely, its cost is amortized.
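To make the loop concrete, here is a minimal FedAvg-style sketch in NumPy. It assumes a toy least-squares problem, plain SGD as the inner optimizer, and `np.mean` over replicas standing in for a real all-reduce; `local_sgd` and `fedavg_round` are hypothetical names. The outer gradient is taken as "start minus end" so that subtracting it moves the model like an ordinary gradient step.

```python
import numpy as np

def local_sgd(params, x, y, steps, lr):
    """Run `steps` of plain SGD on 0.5 * mean((x @ p - y)^2), with no communication."""
    p = params.copy()
    for _ in range(steps):
        grad = x.T @ (x @ p - y) / len(y)
        p -= lr * grad
    return p

def fedavg_round(global_params, shards, inner_steps=100, inner_lr=0.1):
    """One communication round: local training on every replica, then one averaging step."""
    outer_grads = []
    for x, y in shards:
        local = local_sgd(global_params, x, y, inner_steps, inner_lr)
        # Outer gradient: start of the round minus end of the round.
        outer_grads.append(global_params - local)
    # Stand-in for an all-reduce: every replica ends up with the same mean.
    avg_outer_grad = np.mean(outer_grads, axis=0)
    # Outer SGD with learning rate 1.0, i.e. plain parameter averaging.
    return global_params - avg_outer_grad

# Usage: 4 replicas, each holding its own shard of a noise-free regression problem.
rng = np.random.default_rng(0)
true_w = np.array([2.0, -1.0, 0.5])
shards = []
for _ in range(4):
    x = rng.normal(size=(50, 3))
    shards.append((x, x @ true_w))

w = np.zeros(3)
for _ in range(5):  # 5 rounds = 5 all-reduces for 500 local steps per replica
    w = fedavg_round(w, shards)
```

Only one averaging step happens per 100 local steps, which is where the amortization comes from.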
Now, those "outer gradients" are quite different from classical gradients, but we can still apply an optimizer to them! FedOpt (arxiv.org/abs/2003.00295) proposes optimizing them with adaptive optimizers such as Adam, further accelerating training.
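A FedOpt-style outer step can be sketched by running Adam on the averaged outer gradient instead of subtracting it directly. This is a minimal NumPy sketch using standard bias-corrected Adam; the class name `OuterAdam` and its default hyperparameters are illustrative assumptions, and the FedOpt paper's exact variant may differ in details.

```python
import numpy as np

class OuterAdam:
    """Adam applied to averaged outer gradients (a FedAdam-style outer/server step)."""

    def __init__(self, lr=1e-2, beta1=0.9, beta2=0.99, eps=1e-8):
        self.lr, self.beta1, self.beta2, self.eps = lr, beta1, beta2, eps
        self.m = None  # first moment of the outer gradients
        self.v = None  # second moment of the outer gradients
        self.t = 0     # number of outer steps taken so far

    def step(self, params, outer_grad):
        if self.m is None:
            self.m = np.zeros_like(params)
            self.v = np.zeros_like(params)
        self.t += 1
        self.m = self.beta1 * self.m + (1 - self.beta1) * outer_grad
        self.v = self.beta2 * self.v + (1 - self.beta2) * outer_grad**2
        # Bias correction for the zero-initialized moments.
        m_hat = self.m / (1 - self.beta1**self.t)
        v_hat = self.v / (1 - self.beta2**self.t)
        return params - self.lr * m_hat / (np.sqrt(v_hat) + self.eps)

opt = OuterAdam(lr=0.1)
new_params = opt.step(np.zeros(2), np.array([0.5, -2.0]))
```

On the first step, bias-corrected Adam moves each parameter by roughly `lr` against the sign of its outer gradient, regardless of magnitude, so here `new_params` is close to `[-0.1, 0.1]`.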
However, most FL algorithms used SGD as the inner/local optimizer, which is suboptimal for training today's large-scale transformers. In DiLoCo (arxiv.org/abs/2311.08105) from my team at @GoogleDeepMind, we use Adam as the inner optimizer but, more importantly, Nesterov momentum as the outer optimizer. This simple change is ridiculously powerful:
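The outer Nesterov step can be sketched as SGD with Nesterov momentum applied to the averaged outer gradient, following the same update rule as PyTorch's `torch.optim.SGD(..., nesterov=True)` with no dampening. `OuterNesterov` is a hypothetical name and its default `lr`/`momentum` values are illustrative assumptions, not a claim about DiLoCo's exact settings. The inner Adam loop is omitted: anything that returns the averaged outer gradient of a round plugs in here.

```python
import numpy as np

class OuterNesterov:
    """SGD with Nesterov momentum on outer gradients (PyTorch-style, no dampening)."""

    def __init__(self, lr=0.7, momentum=0.9):
        self.lr = lr
        self.momentum = momentum
        self.buf = None  # momentum buffer, carried across communication rounds

    def step(self, params, outer_grad):
        if self.buf is None:
            self.buf = np.zeros_like(params)
        # Accumulate outer gradients across rounds.
        self.buf = self.momentum * self.buf + outer_grad
        # Nesterov look-ahead: current outer gradient plus the damped buffer.
        update = outer_grad + self.momentum * self.buf
        return params - self.lr * update

opt = OuterNesterov(lr=1.0, momentum=0.5)
p1 = opt.step(np.zeros(1), np.ones(1))  # buf = 1.0, update = 1.5
p2 = opt.step(p1, np.ones(1))           # buf = 1.5, update = 1.75
```

With `lr=1.0` and `momentum=0.0` this reduces to the plain averaging of FedAvg; the momentum buffer is what carries information from one round to the next.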
With DiLoCo, you can be as FLOPs/token-efficient as data-parallel training while using two orders of magnitude less bandwidth. A number of papers (arxiv.org/abs/2409.13198, arxiv.org/abs/2405.10853, arxiv.org/abs/2407.07852) have since reused this winning recipe on large-scale transformers, including OpenDiLoCo.
In particular, @PrimeIntellect's OpenDiLoCo is training a 10B (!) DiLoCo model across the world using Hivemind (github.com/learning-at-home/…), synchronizing only every 100 steps and downcasting the outer gradients to int8. That's a 400x bandwidth reduction (100x fewer syncs, each 4x smaller than fp32). The reduction is so large that, as @samsja19 noted, their speed bottleneck is now checkpointing to disk, not communication.
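The 400x figure is simple arithmetic: syncing every 100 steps instead of every step divides the number of all-reduces by 100, and sending int8 instead of fp32 divides each message by 4. The sketch below checks that and shows one possible int8 scheme. The fp32 baseline is an assumption, and the symmetric per-tensor absmax quantizer (`quantize_int8`, `dequantize_int8`, both hypothetical) is an illustration, not necessarily what OpenDiLoCo uses.

```python
import numpy as np

def quantize_int8(x):
    """Symmetric per-tensor absmax quantization (hypothetical scheme for illustration)."""
    scale = float(np.max(np.abs(x))) / 127.0
    if scale == 0.0:
        scale = 1.0  # all-zero tensor: any scale works
    q = np.clip(np.round(x / scale), -127, 127).astype(np.int8)
    return q, scale

def dequantize_int8(q, scale):
    return q.astype(np.float32) * scale

def bandwidth_reduction(sync_every, baseline_bytes=4, compressed_bytes=1):
    """Bytes sent by per-step fp32 syncing divided by bytes sent by periodic int8 syncing."""
    return sync_every * baseline_bytes / compressed_bytes

print(bandwidth_reduction(100))  # 400.0

# Usage: a 10B-parameter model, syncing every 100 steps with int8 outer gradients.
n_params = 10e9
dp_bytes_per_step = n_params * 4            # fp32 all-reduce on every step (~40 GB)
diloco_bytes_per_step = n_params * 1 / 100  # int8 all-reduce, amortized over 100 steps
ratio = dp_bytes_per_step / diloco_bytes_per_step

outer_grad = np.random.default_rng(0).normal(scale=1e-3, size=1000).astype(np.float32)
q, scale = quantize_int8(outer_grad)
recovered = dequantize_int8(q, scale)  # per-element error is at most about scale / 2
```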
Godspeed @PrimeIntellect! 🫡
The other approach made somewhat public is DisTrO from @NousResearch. For now there are very few details, so it's hard to say how legit it is (read their report). But we know that one of the tricks used is adding a regularization loss that forces all replicas to stay close to the global model, which is synced once in a while. This is very reminiscent of Elastic Averaging SGD (arxiv.org/abs/1412.6651, @ylecun) and FedProx (arxiv.org/abs/1812.06127), which have also been revisited recently with PAPA (arxiv.org/abs/2304.03094).
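DisTrO's exact loss isn't public, so this sketch shows the FedProx-style version of the idea instead: each replica minimizes its local loss plus a proximal term `(mu / 2) * ||w - w_global||^2`, whose gradient `mu * (w - w_global)` pulls it back toward the last synced global model. The quadratic local loss and the names `proximal_local_step` and `run_replica` are hypothetical.

```python
import numpy as np

def proximal_local_step(w, w_global, task_grad, lr, mu):
    """One local SGD step on loss(w) + (mu / 2) * ||w - w_global||^2."""
    # Gradient of the proximal term: pulls the replica toward the global model.
    prox_grad = mu * (w - w_global)
    return w - lr * (task_grad + prox_grad)

def run_replica(w_global, local_target, steps=500, lr=0.1, mu=1.0):
    """Local training on a toy loss 0.5 * ||w - local_target||^2 with a proximal term."""
    w = w_global.copy()
    for _ in range(steps):
        task_grad = w - local_target  # gradient of the toy local loss
        w = proximal_local_step(w, w_global, task_grad, lr, mu)
    return w

w_global = np.zeros(2)
local_target = np.full(2, 2.0)
w_prox = run_replica(w_global, local_target, mu=1.0)  # settles at 1.0
w_free = run_replica(w_global, local_target, mu=0.0)  # drifts to 2.0
```

With `mu=0` the replica drifts all the way to its own optimum; with `mu=1` it settles at `(local_target + mu * w_global) / (1 + mu)`, halfway between. Elastic Averaging SGD adds the symmetric move, letting the global (center) variable also be pulled toward the replicas.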
I strongly suspect it is still combined with some kind of outer optimization, in order to remain flops-efficient. We'll see.
The main point of this post is that there is very little that's new in recent distributed work, but a lot of past federated/distributed research was, I think, artificial and done in toy settings. Now the time is ripe for distributed training; see this @SemiAnalysis_ post: semianalysis.com/p/multi-dat…. There is a lot of alpha in revisiting past ideas and investing heavily in the engineering needed to make them work, for real, at scale.