For the last few weeks I've been learning NKI, AWS's kernel language for writing custom ops on Trainium and Inferentia. I just submitted my first kernel to the nki-samples repo: a decode-step attention kernel with GQA. A few things that clicked along the way:
> Decode attention is memory-bound, not compute-bound. Generating one token is tiny math (a single query), but it re-reads the entire growing KV cache every step. So the whole design is about touching K/V memory as few times as possible, not about FLOPs.
> Grouped-query attention is a direct win here. When several query heads share one KV head, you load that K/V tile once and let the whole group ride on it. On a memory-bound kernel that saves exactly the thing that costs you.
> Online softmax is what lets you stream the KV cache in tiles instead of holding every logit at once. You carry a running max, denominator, and accumulator across tiles and rescale them as each new tile arrives. Same answer, bounded memory.
> The hardware model is genuinely different coming from CUDA: matmul results can only exit through PSUM (a tiny accumulator), so you immediately evacuate them to SBUF to free PSUM for the next matmul and to let the softmax engines read them.
> Status: validated on CPU against a NumPy reference, not yet on real Neuron hardware.
> Next up: the split-KV flash-decoding variant for long context.
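For anyone who wants the GQA and online-softmax points above in concrete form, here is a minimal NumPy sketch. It is plain Python, not NKI code and not the submitted kernel. The function names, tensor layouts (q as `(n_q_heads, head_dim)`, k/v as `(n_kv_heads, seq_len, head_dim)`), and the tile size are assumptions made for illustration.

```python
import numpy as np

def decode_attention_reference(q, k, v):
    """Naive version: one query head at a time, all logits held at once."""
    n_q_heads, head_dim = q.shape
    group = n_q_heads // k.shape[0]          # query heads per KV head
    out = np.empty_like(q)
    for h in range(n_q_heads):
        kv = h // group                      # KV head shared by this group
        logits = k[kv] @ q[h] / np.sqrt(head_dim)
        p = np.exp(logits - logits.max())
        out[h] = (p / p.sum()) @ v[kv]
    return out

def decode_attention_tiled(q, k, v, tile=128):
    """Tiled version: each K/V tile is read once per group, online softmax."""
    n_q_heads, head_dim = q.shape
    n_kv_heads, seq_len, _ = k.shape
    group = n_q_heads // n_kv_heads
    scale = 1.0 / np.sqrt(head_dim)
    out = np.empty_like(q)
    for kv in range(n_kv_heads):
        heads = slice(kv * group, (kv + 1) * group)
        qg = q[heads] * scale                # (group, head_dim)
        m = np.full(group, -np.inf)          # running max per query head
        l = np.zeros(group)                  # running softmax denominator
        acc = np.zeros((group, head_dim))    # running weighted sum of V
        for start in range(0, seq_len, tile):
            k_tile = k[kv, start:start + tile]   # read once, used by whole group
            v_tile = v[kv, start:start + tile]
            s = qg @ k_tile.T                    # (group, tile) logits
            m_new = np.maximum(m, s.max(axis=1))
            alpha = np.exp(m - m_new)            # rescale factor for old state
            p = np.exp(s - m_new[:, None])
            l = alpha * l + p.sum(axis=1)
            acc = alpha[:, None] * acc + p @ v_tile
            m = m_new
        out[heads] = acc / l[:, None]
    return out

# Hypothetical sizes: 8 query heads sharing 2 KV heads, 300 cached tokens.
rng = np.random.default_rng(0)
q = rng.standard_normal((8, 64))
k = rng.standard_normal((2, 300, 64))
v = rng.standard_normal((2, 300, 64))
assert np.allclose(decode_attention_tiled(q, k, v), decode_attention_reference(q, k, v))
```

The tiled version only ever holds one tile of logits per group, and the sequence length does not need to be a multiple of the tile size.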
If you work on Neuron, NKI, or inference kernels, I'd love feedback on the approach.
#AWSNeuron #Trainium #Inferentia #MLSystems #NKI #KernelProgramming #Kernels #DecodeAttentionWithGQA