SOTA Normalization Performance with torch.compile
Introduction

Normalization methods (LayerNorm/RMSNorm) are foundational in deep learning: they normalize intermediate activations, which smooths and stabilizes training. We evaluate and improve torch.compile performance for LayerNorm/RMSNorm on NVIDIA H100 and B200, reaching near-SOTA performance on a kernel-by-kernel basis, with further speedups available through torch.compile's automatic fusion capabilities.

Forwards

LayerNorm

LayerNorm was first introduced in this paper: https://arxiv.org/abs/1607.06450. It normalizes the inputs by subtracting the mean and dividing by the standard deviation, then scales and shifts the result with the learnable parameters gamma (weight) and beta (bias).

RMSNorm

RMSNorm (root mean square normalization) was introduced as a follow-up to LayerNorm in this paper: https://arxiv.org/abs/1910.07467. Instead of centering on the mean, the input is normalized by its root mean square, i.e. the square root of the mean of the squared x values. Gamma (weight) remains as a learnable scaling parameter, but there is no longer a bias term.

The forward passes of LayerNorm and RMSNorm are structurally similar: typically a reduction across the contiguous dimension plus some extra pointwise ops, with RMSNorm usually a bit more efficient since it has fewer FLOPs and no bias (both are written out in formulas and code at the end of this section). For the purposes of this study, we present benchmark results for LayerNorm and RMSNorm interchangeably, given the similarity of the kernels.

Quack

Quack is a library of hyper-optimized CuteDSL kernels from Tri Dao: https://github.com/Dao-AILab/quack. Its current README shows Quack outperforming torch.compile for these reduction kernels on H100. We use Quack as the SOTA baseline against which we evaluate the performance of torch.compile. Quack's README showcases earlier torch.compile results, in which torch.compile typically reaches only ~50% of Quack's performance.

torch.compile

Here we illustrate the general logic of a torch.compile-generated kernel for the LayerNorm forward pass (the approach for RMSNorm is the same). We assume that the input reduction dimension (rnumel) is contiguous, which we refer to in Inductor as an inner reduction. While the kernel might look a bit confusing, what's actually happening is very simple (a Python sketch of this structure appears at the end of this section):

- Maintain partial sums of size R_BLOCK for each row of the input X
- Use the partial sums to calculate the mean and variance
- Apply the LayerNorm formula elementwise to X
- Store the elementwise output
- Store the mean and variance for the backward pass if elementwise_affine=True and requires_grad=True

As a side note, if R is smaller than some heuristic (1024), Inductor generates a persistent reduction, in which we no longer loop over the r dimension and instead compute the mean directly.

Comparing the torch.compile and Quack versions of the RMSNorm forward pass, we can reproduce the poor performance of torch.compile relative to Quack on H100 and B200. However, after autotuning and using those results to motivate new Inductor defaults, we arrive at SOTA performance on H100 and B200. In general, the following was done to achieve this result:

- Inserting torch._dynamo.reset() during benchmarking (see the snippet below) – this makes sure that torch.compile does not use automatic dynamic shapes: previously, a torch.compile call per shape was performed, and without a reset the resulting recompilations make the compiler assume dynamic shapes.
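The snippet below is a minimal sketch of that benchmarking pattern, not the exact harness used for these results; the shapes are hypothetical, and a real measurement would use a proper timer such as triton.testing.do_bench rather than a bare call.

```python
import torch
import torch.nn.functional as F

def ln_fwd(x, w, b):
    return F.layer_norm(x, (x.shape[-1],), w, b)

# Hypothetical (rows, reduction-dim) shapes, for illustration only.
for m, n in [(32768, 1024), (32768, 8192)]:
    torch._dynamo.reset()  # clear compiler state so each shape compiles as static
    compiled_ln = torch.compile(ln_fwd)
    x = torch.randn(m, n, device="cuda", dtype=torch.bfloat16)
    w = torch.randn(n, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(n, device="cuda", dtype=torch.bfloat16)
    compiled_ln(x, w, b)  # first call compiles for this exact (static) shape
```

Without the reset, Dynamo would notice the recompilation for the second shape and mark the changing dimensions dynamic, which is exactly the behavior the benchmarking setup needs to avoid.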
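For completeness, the forward definitions the kernels implement, written out from the papers linked above (up to where implementations place the small stability constant epsilon):

```latex
\mathrm{LayerNorm}(x)_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}\,\gamma_i + \beta_i,
\qquad
\mu = \frac{1}{n}\sum_{j=1}^{n} x_j,\quad
\sigma^2 = \frac{1}{n}\sum_{j=1}^{n}\left(x_j - \mu\right)^2
```

```latex
\mathrm{RMSNorm}(x)_i = \frac{x_i}{\mathrm{RMS}(x)}\,\gamma_i,
\qquad
\mathrm{RMS}(x) = \sqrt{\frac{1}{n}\sum_{j=1}^{n} x_j^2 + \epsilon}
```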

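And as code, a minimal eager PyTorch sketch of both forward passes (the function names and eps defaults here are ours, not from either paper), which makes the structural similarity concrete:

```python
import torch

def layernorm_fwd(x, gamma, beta, eps=1e-5):
    # Reduce across the contiguous last dimension ("inner reduction" in Inductor terms).
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, unbiased=False, keepdim=True)
    return (x - mean) / torch.sqrt(var + eps) * gamma + beta

def rmsnorm_fwd(x, gamma, eps=1e-6):
    # Same reduction structure, but no mean-centering and no bias: fewer FLOPs.
    rms = torch.sqrt((x * x).mean(dim=-1, keepdim=True) + eps)
    return x / rms * gamma
```

Both functions reduce over the same contiguous dimension, which is why the generated kernels for the two norms follow the same inner-reduction pattern.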

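Finally, here is a pure-PyTorch sketch of the inner-reduction structure described in the bulleted steps above. It is illustrative only: the real generated kernel is Triton, computes the mean and variance with a Welford reduction rather than the simpler sum/sum-of-squares formulation below, and the block size and names here are hypothetical.

```python
import torch

def layernorm_inner_reduction_sketch(X, gamma, beta, eps=1e-5, R_BLOCK=2048):
    M, R = X.shape  # R (rnumel) is the contiguous reduction dimension
    sum_x = torch.zeros(M, device=X.device, dtype=torch.float32)
    sum_xx = torch.zeros(M, device=X.device, dtype=torch.float32)
    # 1) Loop over the r dimension in R_BLOCK chunks, maintaining partial sums per row.
    #    (If R were below the persistent-reduction heuristic, there would be no loop.)
    for r0 in range(0, R, R_BLOCK):
        chunk = X[:, r0 : r0 + R_BLOCK].float()
        sum_x += chunk.sum(dim=1)
        sum_xx += (chunk * chunk).sum(dim=1)
    # 2) Use the partial sums to compute per-row mean and (reciprocal of) std.
    mean = sum_x / R
    rstd = torch.rsqrt(sum_xx / R - mean * mean + eps)
    # 3) + 4) Apply the LayerNorm formula elementwise and store the output.
    out = ((X.float() - mean[:, None]) * rstd[:, None]) * gamma + beta
    # 5) mean and rstd are also saved when needed for the backward pass.
    return out.to(X.dtype), mean, rstd
```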