Here's a draft of the tech report on the model training method I've been experimenting with, "Parallax".
chutes.ai/parallax.pdf
TL;DR: most of an MoE model's params are in the routed experts, so you can massively reduce VRAM and FLOPs per participant by splitting those experts up across participants. You can also offload the expert training to commodity hardware, further saving compute/VRAM per island.
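To make the "mostly routed experts" point concrete, here's a rough back-of-the-envelope sketch. The config values are made up for illustration (they are not the Parallax test models), norms and biases are ignored, and the per-island count ignores the surrogates for non-owned experts.

```python
from dataclasses import dataclass


@dataclass
class MoEConfig:
    # All values are hypothetical, chosen only to illustrate the arithmetic.
    d_model: int = 4096
    d_ff: int = 2048        # hidden size of each routed expert
    n_layers: int = 32
    n_experts: int = 64     # routed experts per layer
    vocab: int = 100_000


def param_counts(cfg):
    """Return (backbone_params, routed_expert_params), ignoring norms and biases."""
    attn = 4 * cfg.d_model * cfg.d_model      # Q, K, V, O projections
    router = cfg.d_model * cfg.n_experts      # one routing logit per expert
    expert = 3 * cfg.d_model * cfg.d_ff       # gated MLP: up, gate, down
    embed = 2 * cfg.vocab * cfg.d_model       # untied input/output embeddings
    backbone = cfg.n_layers * (attn + router) + embed
    routed = cfg.n_layers * cfg.n_experts * expert
    return backbone, routed


def per_island_params(cfg, n_islands):
    """Params an island holds if it keeps the full backbone plus 1/n of the experts."""
    backbone, routed = param_counts(cfg)
    return backbone + routed // n_islands


cfg = MoEConfig()
backbone, routed = param_counts(cfg)
print(f"routed share of params: {routed / (backbone + routed):.1%}")
print(f"per-island params with 8 islands: {per_island_params(cfg, 8) / 1e9:.1f}B")
```

With a config shaped like this, the routed experts dominate the total, so sharding them is where nearly all of the per-participant savings come from.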
The crazy cool thing about the sketches is that you can onboard workers nearly instantly (the sync payload with ternary weights is a few MB), and they never need to download or stream the raw datasets (the sketches are tiny and contain all the work they need to do). If you had sensitive data, you'd probably want to keep the first couple of layers of the model in your own infra, because otherwise a worker could run gradient inversion attacks to reconstitute the raw text. Beyond the first couple of layers, and without knowing which layer/expert you're training, I think that's infeasible, so privacy is pretty baked in.
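For a feel of why ternary weights make the sync payload so small, here's a minimal packing sketch. It assumes weights are already quantized to {-1, 0, 1} and uses a simple 5-trits-per-byte encoding (3^5 = 243 fits in one byte, so about 1.6 bits per weight). The function names and the encoding are my illustration, not necessarily the wire format the report uses.

```python
import numpy as np

# Place values for 5 base-3 digits packed into one byte.
_POWERS = np.array([1, 3, 9, 27, 81], dtype=np.uint16)


def pack_ternary(w):
    """Pack values in {-1, 0, 1} into bytes, 5 weights per byte."""
    trits = (np.asarray(w, dtype=np.int8).ravel() + 1).astype(np.uint8)  # -> {0, 1, 2}
    pad = (-len(trits)) % 5                   # pad up to a multiple of 5
    trits = np.concatenate([trits, np.zeros(pad, dtype=np.uint8)]).reshape(-1, 5)
    return (trits.astype(np.uint16) @ _POWERS).astype(np.uint8)  # max value is 242


def unpack_ternary(packed, n):
    """Invert pack_ternary, returning the first n weights as int8 in {-1, 0, 1}."""
    vals = packed.astype(np.uint16)
    digits = []
    for _ in range(5):                        # peel off base-3 digits, lowest first
        digits.append(vals % 3)
        vals = vals // 3
    trits = np.stack(digits, axis=1).reshape(-1)[:n]
    return trits.astype(np.int8) - 1


def packed_bytes(n_params):
    """Bytes needed to ship n_params ternary weights with this encoding."""
    return -(-n_params // 5)                  # ceiling division
```

Under this encoding, a hypothetical 25M-param expert would ship as roughly 5 MB, versus 50 MB in bf16.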
The main ingredients: decoupled DiLoCo/RDA-DiLoCo-style backbone sync, surrogates for non-owned routed experts (synced via low-rank updates), tiered sync cadences for the various components, "sketches" to offload expert work, etc.
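For anyone who hasn't seen DiLoCo-style sync, here's a toy sketch of the generic pattern the backbone sync builds on: each worker takes several local steps, then the workers' pseudo-gradients (global minus local params) are averaged and applied with Nesterov-momentum SGD. This is plain DiLoCo on a toy quadratic, not the decoupled/RDA variant from the report. The inner optimizer is plain SGD as a stand-in (DiLoCo uses AdamW), and the hyperparameters are illustrative.

```python
import numpy as np


def inner_steps(theta, grad_fn, steps=10, lr=0.1):
    """Local training on one worker; plain SGD stand-in for the inner optimizer."""
    theta = theta.copy()
    for _ in range(steps):
        theta -= lr * grad_fn(theta)
    return theta


def outer_step(theta_global, worker_thetas, momentum_buf, outer_lr=0.7, beta=0.9):
    """Average the workers' pseudo-gradients and apply Nesterov-momentum SGD."""
    pseudo_grad = np.mean([theta_global - t for t in worker_thetas], axis=0)
    momentum_buf = beta * momentum_buf + pseudo_grad
    update = pseudo_grad + beta * momentum_buf      # Nesterov look-ahead
    return theta_global - outer_lr * update, momentum_buf


def simulate(rounds=50, n_workers=4, dim=4, seed=0):
    """Toy run: worker i minimizes 0.5 * ||theta - c_i||^2 on its own shard."""
    rng = np.random.default_rng(seed)
    centers = rng.normal(size=(n_workers, dim))
    theta = np.zeros(dim)
    buf = np.zeros(dim)
    for _ in range(rounds):
        # Each worker starts from the current global params and trains locally.
        local_thetas = [inner_steps(theta, lambda t, c=c: t - c) for c in centers]
        theta, buf = outer_step(theta, local_thetas, buf)
    # Distance from the global optimum (the mean of the workers' targets).
    return float(np.linalg.norm(theta - centers.mean(axis=0)))
```

The key property is that workers only communicate once per round rather than once per step, which is what makes the slow-link setting workable.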
Tested at 20B two different ways, with plenty of small-model iterations, and at 176B params just to prove out feasibility.
There are hundreds of additional experiments and loads of data we could also highlight, but the guts are there.
Variations:
- freeze the routed expert weights instead of using surrogates; this eliminates the Adam state, backward pass, etc., though you'd need a different sync method than the low-rank updates used for the surrogates (the surrogates are already tiny, and rank-8 updates to them are even smaller)
- hierarchical Parallax, i.e. each node itself becomes an RDA-DiLoCo-style multi-learner which then syncs with the outer/global islands (the point is to enable GPUs without NVLink etc. and reduce GPU<=>GPU comms to maximize MFU on less-capable/commodity GPUs)
- pipeline-parallelize each island itself, so that each island can decompose the backbone etc.
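On the rank-8 remark in the first variation, here's a quick sketch of why low-rank updates are so cheap to sync: a rank-r update to a d_out x d_in matrix only needs r * (d_out + d_in) numbers instead of d_out * d_in. Truncated SVD is used here as one way to get the factors. That choice and the shapes are assumptions for illustration, not necessarily how the report computes its updates.

```python
import numpy as np


def low_rank_factors(delta, rank=8):
    """Best rank-`rank` approximation of a weight delta, via truncated SVD."""
    u, s, vt = np.linalg.svd(delta, full_matrices=False)
    a = u[:, :rank] * s[:rank]     # (d_out, rank), singular values folded in
    b = vt[:rank, :]               # (rank, d_in)
    return a, b                    # receiver reconstructs the update as a @ b


def payload_ratio(d_out, d_in, rank=8):
    """Numbers sent for the two factors, relative to sending the dense delta."""
    return rank * (d_out + d_in) / (d_out * d_in)


# Hypothetical surrogate shape, just to show the ratio.
print(f"{payload_ratio(256, 64):.1%} of the dense payload")
```

The ratio shrinks further as the matrices get larger, since the factor size grows linearly in the dimensions while the dense delta grows quadratically.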