spent the last month building my own framework to train a diffusion model from scratch. it was hard
almost like i just learned to cast an ancient spell that requires lots of mysterious steps and ingredients. for a long time i was trying, and nothing happened. but when it worked it felt like magic
i've learned a lot so wanted to share a bit 🧵
- i'm doing *conditional* diffusion, trying to produce outputs x that depend on some inputs y. my biggest blocker was that the architectural biases matter here – in my setup i could NOT just put the conditioning directly into the input, or the model would learn to map y straight to x instead of using y to help denoise the noisy input x. (the loss will go down but sampling will not work)
- thus the diffusion world has a zoo of "conditional" architectures that can be a little challenging to adapt for your problem. but you have to use one or else things just won't work
- apparently, architecture still matters in vision (sad). initialization, residuals, and extra normalization can make all the difference
- learning a small "probe" alongside your diffusion model is hugely valuable. you can just cut the gradients to the probe so that it doesn't affect training. this way you will know when you beat the baseline. (i'm not sure if this is common practice but it was invaluable for me)
- you need to incorporate sampling into training every-so-often. otherwise you will never figure out why your model doesn't work
- the normalization is super important. your input data needs to have ~mean 0 and std 1. otherwise learning might not work, or will be super slow
- in diffusion a lot of things can have the same shape but be different "types" in the sense that they're incompatible in some way. these bugs are easy to make and the code will still run. but you can often find them by checking that the norms, stds, and means are approximately correct
- complex systems that you write from scratch will inevitably have tons of bugs. you can start with trying to learn the identity function (in diffusion just set the noise to zeros). if you can't do this something is broken. in my case this helped me realize one of my losses had a sign flipped
- in my opinion the loss after 1000 steps or so is usually a reliable signal for debugging architectural changes
- diffusion people look down on DDPM as old and outdated but turns out it's still "good enough for government work" and worked fine for me eventually
- wouldn't recommend the diffusers library. not sure it's really being developed anymore. heard the openai impl is much better
- in general building systems from scratch is a slow and frustrating way to do research and i would recommend most people just start with a good codebase and tweak it to fit their problem. but if you build everything yourself you will learn a lot and feel a deep sense of satisfaction when it all starts working :)
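
on the normalization and "same shape, different type" points: a minimal sketch of what a stats check could look like, in numpy. `normalize` and `check_stats` are hypothetical helper names (not from any library), and the 0.25 tolerance is an arbitrary assumption you'd tune for your own data.

```python
import numpy as np

def normalize(x, eps=1e-8):
    """Shift and scale x to roughly mean 0 and std 1 (statistics taken over all elements)."""
    x = np.asarray(x, dtype=np.float64)
    return (x - x.mean()) / (x.std() + eps)

def check_stats(name, x, expected_mean=0.0, expected_std=1.0, tol=0.25):
    """Print mean/std/norm of x and return True if mean and std are within tol of the expected values.

    tol is an assumed, arbitrary threshold; pick one that matches your data.
    """
    x = np.asarray(x, dtype=np.float64)
    mean, std = float(x.mean()), float(x.std())
    norm = float(np.linalg.norm(x))
    ok = abs(mean - expected_mean) <= tol and abs(std - expected_std) <= tol
    status = "ok" if ok else "SUSPICIOUS"
    print(f"[{status}] {name}: mean={mean:.3f} std={std:.3f} norm={norm:.3f}")
    return ok

# toy example: raw pixel-like values fail the check, normalized ones pass
rng = np.random.default_rng(0)
raw = rng.uniform(0.0, 255.0, size=10_000)
check_stats("raw", raw)
check_stats("normalized", normalize(raw))
```

calling something like `check_stats` on the inputs, the noise, the noisy inputs, and the model outputs every so often is one cheap way to catch tensors that have the right shape but the wrong scale.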
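
on the "learn the identity function" trick: a rough numpy sketch, assuming a DDPM-style forward process and a model that predicts the clean x0. the "model" here is just a linear map trained with hand-written gradient descent, standing in for a real network. `q_sample` and `identity_sanity_check` are hypothetical names, and `alpha_bar=0.9` is an arbitrary fixed noise level. with the noise set to zeros the task is trivial, so if the loss doesn't collapse, something in the pipeline is broken.

```python
import numpy as np

def q_sample(x0, alpha_bar, noise):
    """DDPM forward process: x_t = sqrt(alpha_bar) * x0 + sqrt(1 - alpha_bar) * noise."""
    return np.sqrt(alpha_bar) * x0 + np.sqrt(1.0 - alpha_bar) * noise

def identity_sanity_check(dim=8, steps=200, lr=1.0, alpha_bar=0.9, seed=0):
    """Train a linear x0-predictor on zero-noise inputs and return the loss at every step."""
    rng = np.random.default_rng(seed)
    W = rng.normal(scale=0.1, size=(dim, dim))  # stand-in for a real denoiser
    losses = []
    for _ in range(steps):
        x0 = rng.normal(size=(64, dim))
        # the debugging trick: zero out the noise so x_t is just a scaled copy of x0
        x_t = q_sample(x0, alpha_bar, np.zeros_like(x0))
        err = x_t @ W - x0
        losses.append(float(np.mean(err ** 2)))
        grad = 2.0 * x_t.T @ err / err.size  # gradient of the mean squared error w.r.t. W
        W -= lr * grad
    return losses

losses = identity_sanity_check()
print(f"first loss: {losses[0]:.4f}  last loss: {losses[-1]:.2e}")
```

the same idea carries over to a real training loop: swap in your actual model and loss, zero the noise, and check that the loss goes to roughly zero. a flipped sign in a loss term, for example, would show up as a loss that never collapses.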
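
on running sampling during training: a minimal sketch of standard DDPM ancestral sampling in numpy, so there is something concrete to call every N training steps. the linear beta schedule, the 100 steps, and the choice sigma_t^2 = beta_t are assumptions, and `make_schedule`, `ddpm_sample`, and `eps_model` are hypothetical names. for conditional diffusion, `eps_model` would also take (or close over) the conditioning y. the toy `eps_model` below is the exact noise predictor for data that is always zero, so the sampler should land on (almost exactly) zero.

```python
import numpy as np

def make_schedule(num_steps=100, beta_start=1e-4, beta_end=0.02):
    """Linear beta schedule (an assumed choice) plus the derived alpha terms."""
    betas = np.linspace(beta_start, beta_end, num_steps)
    alphas = 1.0 - betas
    alpha_bars = np.cumprod(alphas)
    return betas, alphas, alpha_bars

def ddpm_sample(eps_model, shape, betas, alphas, alpha_bars, rng):
    """Start from Gaussian noise and step t = T-1, ..., 0 using the predicted noise."""
    x = rng.normal(size=shape)
    for t in reversed(range(len(betas))):
        eps_hat = eps_model(x, t)
        # posterior mean given the predicted noise
        mean = (x - betas[t] / np.sqrt(1.0 - alpha_bars[t]) * eps_hat) / np.sqrt(alphas[t])
        # add fresh noise at every step except the final one
        noise = rng.normal(size=shape) if t > 0 else np.zeros(shape)
        x = mean + np.sqrt(betas[t]) * noise
    return x

betas, alphas, alpha_bars = make_schedule()

def toy_eps_model(x, t):
    # if the data is always 0, then x_t = sqrt(1 - alpha_bar_t) * eps, so eps = x_t / sqrt(1 - alpha_bar_t)
    return x / np.sqrt(1.0 - alpha_bars[t])

samples = ddpm_sample(toy_eps_model, (4, 3), betas, alphas, alpha_bars, np.random.default_rng(0))
print("max |sample|:", float(np.abs(samples).max()))
```

in a training loop this could be as simple as calling `ddpm_sample` with the current model every few hundred steps and logging the samples (or their mean and std) next to the loss.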