tw: cursed floating-point numerics and inductive bias story
While porting my connect 4 convnet to run on my ancient sandy bridge thinkpad I noticed the model didn't play quite right. The cpu version of the model chose to play its first move at the edge of the board whereas the desktop gpu (3070) version of the model would start as expected in the center. This wasn't the model being completely broken, as its following moves were sane, and it wasn't due to random sampling as I sampled from it with argmax/t=0.
However, the output probabilities were entirely different between the cpu and 3070 versions for an empty board. The 3070 version would have basically 1.0 at the center of the board and ~1e-12 everywhere else while the cpu version had around 1e-1 everywhere, indicating it was "lost." From the second turn on, though, the cpu model would start producing confident probability distributions again.
Even though my thinkpad would spew a pile of "Could not initialize NNPACK!" warnings at init, this wasn't a first-iteration/init bug. Indeed, running the same empty board through the model multiple times didn't help and still produced the confused distribution.
Some print debugging revealed that the conv-instancenorm-relu backbone of the model was producing all zeros on cpu while producing large nonzero values on the 3070. On the cpu, every block would output zeros, while on the 3070 small ~1e-7 values would propagate and be amplified, block by block.
The interesting realization is that cpu is actually _more_ correct here! When all inputs are zeros, the output for every spatial position is just the bias of the convolution for a given channel. As instance norm is applied per-channel, the result due to mean-subtraction should be zero. Pop quiz: what does
a = torch.nn.InstanceNorm2d(128)
a(torch.full((1, 128, 6, 7), 0.5)) return? It turns out it's implementation dependent in a few ways!
For constant fill values that are powers of two, such as 0.5 and 1.0, and importantly for 0.0, many implementations return a tensor of all zeros, as expected. However, for other values (such as the per-channel biases of a conv layer), you might see a small value around 1e-6 or 1e-7. Still close to zero but enough to be amplified across multiple layers.
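To get a feel for where a residual like that can come from, here is a minimal pure-Python sketch that emulates float32 arithmetic with `struct` and computes the mean of 42 copies of a constant (one 6x7 channel) by plain left-to-right accumulation. The helper names `f32` and `naive_mean_residual` are made up for this example, and the summation order is an assumption: real InstanceNorm kernels reduce in other orders (vectorized, tree-style, and so on), which is presumably why the results differ between backends.

```python
import struct

def f32(x: float) -> float:
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]

def naive_mean_residual(c: float, n: int = 42) -> float:
    """Return c - mean(n copies of c), rounding every step to float32.

    n = 42 matches one 6x7 channel. Sequential accumulation is an
    assumption of this sketch, not what any particular kernel does.
    """
    c = f32(c)
    total = 0.0
    for _ in range(n):
        total = f32(total + c)  # each partial sum is rounded to float32
    mean = f32(total / n)
    return f32(c - mean)

if __name__ == "__main__":
    for c in (0.0, 0.5, 1.0, 0.1, 0.3, 0.7):
        print(c, naive_mean_residual(c))
```

With 0.0, 0.5, and 1.0 every partial sum is exactly representable, so the residual is exactly zero. For the other constants the partial sums get rounded along the way, so the computed mean can land a few float32 ulps away from the constant. InstanceNorm then divides that residual by sqrt(var + eps), and with the default eps of 1e-5 that is a multiplication by roughly 300, which is in the right ballpark for the 1e-6 or 1e-7 values mentioned above.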
Okay, is this just a special case of cascading errors? Why in particular would some zero/small nonzero values be important for model performance? For that we can turn to one of my all-time favorite convnet papers, "How Much Position Information Do Convolutional Neural Networks Encode?" by Islam et al., which posits that convnets do encode position information after inferring it from zero-padding at the borders of images (or in my case the game board).
I had used (1,1) zero-padding to conveniently preserve spatial dimensions for 3x3 convolutions. This same padding allowed the model to center itself (literally) but was now useless when given an empty board and a _more_ accurate InstanceNorm implementation!
Indeed, adding a tiny bit of noise or even a constant value of 1e-6 to the board state restores the model's expected behavior of a confident center move to open. The output probabilities matched closely with that of the gpu version as well for subsequent moves.
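Here is a toy sketch of why the nudge works, in plain Python so the arithmetic is easy to follow. It is not the actual model: the single conv-instancenorm-relu block below uses all-ones 3x3 weights and a bias of 0.25, both made up for illustration (the bias is a power of two on purpose, so the mean subtraction is exact, like the "more correct" cpu path). The function names are hypothetical too.

```python
import math

ROWS, COLS = 6, 7  # connect 4 board

def conv3x3_zero_pad(board, bias=0.25):
    """3x3 conv with (1,1) zero-padding; toy all-ones weights, made-up bias."""
    out = []
    for r in range(ROWS):
        row = []
        for c in range(COLS):
            acc = bias
            for dr in (-1, 0, 1):
                for dc in (-1, 0, 1):
                    rr, cc = r + dr, c + dc
                    # positions outside the board are padded zeros: skip them
                    if 0 <= rr < ROWS and 0 <= cc < COLS:
                        acc += board[rr][cc]
            row.append(acc)
        out.append(row)
    return out

def instance_norm(x, eps=1e-5):
    """Normalize one channel by its own mean and variance."""
    flat = [v for row in x for v in row]
    mean = sum(flat) / len(flat)
    var = sum((v - mean) ** 2 for v in flat) / len(flat)
    scale = math.sqrt(var + eps)
    return [[(v - mean) / scale for v in row] for row in x]

def block(board):
    """One conv -> instance norm -> relu block."""
    normed = instance_norm(conv3x3_zero_pad(board))
    return [[max(v, 0.0) for v in row] for row in normed]

empty = [[0.0] * COLS for _ in range(ROWS)]
nudged = [[1e-6] * COLS for _ in range(ROWS)]

if __name__ == "__main__":
    print("empty, max activation:", max(max(row) for row in block(empty)))
    out = block(nudged)
    print("nudged, corner vs center:", out[0][0], out[2][3])
```

On the empty board the padded zeros look exactly like the board itself, so the conv output is the bias everywhere and instance norm wipes it out. Once the board holds any nonzero constant, cells near the border see fewer nonzero neighbors than cells in the middle, so the center ends up above the channel mean and the corners below it. That gives later layers a position signal to amplify.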
tldr: know the literal edge cases of your numerical formats and inductive biases
also fun fact: the cpu implementation of InstanceNorm on my thinkpad seems to yield zeros even for non-power-of-two constant inputs, but the implementation on my 5600X cpu does not (!!)