Generating State-of-the-Art GEMMs with TorchInductor's CuteDSL Backend
Introduction

TorchInductor currently supports three autotuning backends for matrix multiplications: Triton, CUTLASS (C++), and cuBLAS. This post describes the integration of CuteDSL as a fourth backend, the technical motivation for the work, and the performance results observed so far.

The kernel-writing DSL space has gained significant momentum, with Triton, Helion, Gluon, CuTile, and CuteDSL each occupying a different point in the abstraction-performance tradeoff. When evaluating whether to integrate a new backend into TorchInductor, we apply three criteria: (1) the integration does not impose a large maintenance burden on our team, or there is a long-term committed effort from the vendor; (2) it does not regress compile time or benchmarking time relative to existing backends; and (3) it delivers better performance on target workloads. CuteDSL satisfies all three. NVIDIA is actively developing CuteDSL and provides optimized kernel templates, which limits the maintenance burden on TorchInductor. Compile times are at parity with our other backends, a significant improvement over the CUTLASS C++ path, which requires full nvcc invocations.

Beyond these immediate benefits, CuteDSL represents a longer-term strategic investment. It is built on the same abstractions as CUTLASS C++, which has demonstrated strong performance on FP8 GEMMs and epilogue fusion, but it is written in Python, has faster compile times, and is less complex to maintain. As NVIDIA continues to invest in CuteDSL performance, CuteDSL is positioned to serve as an eventual replacement for the CUTLASS C++ integration on newer hardware generations, simplifying the TorchInductor codebase. The combination of aligned incentives, growing open-source adoption (Tri Dao's Quack library, Jay Shah at Colfax International), and a lower-level programming model that exposes the full thread and memory hierarchy makes CuteDSL a well-positioned backend for delivering optimal GEMM performance on current and future NVIDIA hardware.

Strategy: Why We Target GEMMs

Not all operations benefit equally from a new backend. For memory-bound operations (elementwise math, activations, and reductions), Triton already generates high-quality code. Its block-level programming model is well suited to these workloads, which only require vectorized memory accesses, and the performance gap between Triton and hand-written kernels is small. CuteDSL can express pointwise operations and reductions, but because of its low-level nature, automatically generating CuteDSL kernels from scratch is complex. In practice the two DSLs produce kernels that perform comparably on these workloads, so the extra complexity would bring no benefit. Our own experiments validate this: we ran Triton and CuteDSL softmax kernels on progressively larger input sizes, and both approach peak memory bandwidth on GB200 (a sketch of the kind of kernel and measurement harness involved appears at the end of this post).

GEMMs are a different story. Matrix multiplications dominate the compute profile of transformer-based models: in a typical LLM forward pass, the GEMMs in the attention projections, FFN layers, and output head account for the majority of GPU cycles. Achieving near-peak utilization on these operations requires precise control over the hardware features that each new GPU generation introduces: tile sizes tuned to the tensor core pipeline, explicit management of shared memory staging, warp-level scheduling, and, on newer architectures like B200, thread block clusters and distributed shared memory.
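On the user side, none of this complexity is exposed: TorchInductor's autotuner simply benchmarks candidate kernels from each backend and picks the fastest. Below is a minimal sketch of how a model could opt into GEMM autotuning through the existing `max_autotune_gemm_backends` config; the `"CUTEDSL"` entry is an assumption for illustration, so check `torch._inductor.config` in your PyTorch build for the exact name the new backend registers under.

```python
import torch
import torch._inductor.config as inductor_config

# Ask Inductor to autotune GEMMs across candidate backends.
# NOTE: the "CUTEDSL" backend name is assumed for illustration; consult
# torch._inductor.config in your PyTorch build for the registered name.
inductor_config.max_autotune_gemm_backends = "ATEN,TRITON,CUTEDSL"

@torch.compile(mode="max-autotune")
def mm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return a @ b

a = torch.randn(8192, 8192, device="cuda", dtype=torch.bfloat16)
b = torch.randn(8192, 8192, device="cuda", dtype=torch.bfloat16)
out = mm(a, b)  # first call compiles, benchmarks candidates, caches the winner
```

The first invocation pays the compile-and-benchmark cost; subsequent calls with the same shapes reuse the cached winning kernel.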

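For reference, the memory-bound softmax comparison mentioned earlier can be reproduced with a harness along the following lines. This is a standard one-row-per-program Triton softmax plus a simple bandwidth model (one read and one write of the tensor), shown as an illustrative sketch rather than the exact benchmark we ran.

```python
import torch
import triton
import triton.language as tl

@triton.jit
def softmax_kernel(out_ptr, in_ptr, row_stride, n_cols, BLOCK: tl.constexpr):
    # One program instance handles one row of the input.
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK)
    mask = cols < n_cols
    x = tl.load(in_ptr + row * row_stride + cols, mask=mask, other=-float("inf"))
    x = x - tl.max(x, axis=0)  # subtract row max for numerical stability
    num = tl.exp(x)
    tl.store(out_ptr + row * row_stride + cols, num / tl.sum(num, axis=0), mask=mask)

def softmax(x: torch.Tensor) -> torch.Tensor:
    rows, cols = x.shape
    out = torch.empty_like(x)
    BLOCK = triton.next_power_of_2(cols)
    softmax_kernel[(rows,)](out, x, x.stride(0), cols, BLOCK=BLOCK)
    return out

for n_rows in (1024, 4096, 16384, 65536):
    x = torch.randn(n_rows, 4096, device="cuda", dtype=torch.float32)
    ms = triton.testing.do_bench(lambda: softmax(x))
    gbps = 2 * x.numel() * x.element_size() / (ms * 1e-3) / 1e9  # read + write
    print(f"rows={n_rows}: {gbps:.0f} GB/s")
```

A kernel like this is limited by the same loads and stores regardless of which DSL emits it, which is why the Triton and CuteDSL variants converge to the same bandwidth ceiling as inputs grow.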
