# Training-Inference Parity in MoE Models: Where Numerics Drift

3/10/2026
Kernel fusions that are mathematically equivalent can still drift numerically. Here are the parity bugs we hit across both Kimi K2.5 serving and Qwen3.5-MoE training bring-up.

When you train a model and serve it for inference, you expect them to agree: the same weights, the same input, the same output distribution. This training–inference numerical parity matters more than it sounds.

For dense models, parity is relatively easy. Mixture-of-Experts models like Kimi K2.5, Qwen3.5-MoE, and DeepSeek V3 are harder. With routed experts, shared expert pathways, and all-reduce communication twice per layer across deep stacks, there are many places where "mathematically equivalent" optimizations produce numerically different results.

This post catalogs the pitfalls we found. Each is a class of optimization that inference engines use for performance, but that can silently break numerical alignment. We found most of these while bringing up Kimi K2.5 on our serving stack, then saw the same failure mode again while debugging Qwen3.5-MoE. We will use FlashInfer- and TRT-LLM-style fused kernels as concrete examples.

## Floating-point addition is not associative

Every pitfall below reduces to one fact: floating-point addition is not associative. Even in FP32:

```
(a + b) + c ≠ a + (b + c)
```

Each addition rounds its result to the nearest representable value. Different orderings produce different intermediate values, which pick up different rounding errors. The errors are tiny per operation, but they compound through 61 transformer layers, and MoE routing amplifies them: a small change in the hidden state can flip which experts get selected, cascading through the rest of the network.

## All-reduce summation order

In tensor-parallel inference, every linear layer's output must be summed across GPUs via all-reduce. This happens twice per layer: after the attention output projection, and after the MLP/MoE block.

Training typically uses NCCL, which implements all-reduce as a reduce-scatter followed by an all-gather. In the reduce-scatter phase with a ring topology, the data is divided into chunks, one per GPU. Each chunk is accumulated as partial sums flow around the ring, starting from the GPU that "owns" that chunk. For 8 GPUs, this means different parts of the hidden vector see different summation orders:

```
NCCL ring reduce-scatter (8 GPUs):
chunk 0 (owned by GPU0): r0 + r7 + r6 + r5 + r4 + r3 + r2 + r1
chunk 1 (owned by GPU1): r1 + r0 + r7 + r6 + r5 + r4 + r3 + r2
chunk 2 (owned by GPU2): r2 + r1 + r0 + r7 + r6 + r5 + r4 + r3
...each chunk starts from its owner, accumulates around the ring
```

Inference serving engines often replace NCCL with custom all-reduce kernels for lower latency. FlashInfer's Lamport IPC kernel (derived from TRT-LLM) takes a different approach: each GPU writes its data to all other GPUs' buffers via CUDA IPC, then every GPU reads all contributions locally and sums them in a fixed order:

```
Lamport kernel (all elements, on every GPU):
every chunk: r0 + r1 + r2 + r3 + r4 + r5 + r6 + r7
```

Both compute the same mathematical sum; only the summation order differs, and in floating point that is enough to produce bitwise-different activations.
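As a concrete instance of the non-associativity above, here is a minimal FP32 sketch (NumPy scalars, toy operands of our own choosing): adding the small value to the large operand first absorbs it entirely.

```python
import numpy as np

# (a + b) + c vs. a + (b + c) in FP32: same operands, different results.
a, b, c = np.float32(1e8), np.float32(-1e8), np.float32(1.0)

print((a + b) + c)  # 1.0 -- a and b cancel exactly, then c survives
print(a + (b + c))  # 0.0 -- b + c rounds back to -1e8 (FP32 spacing at
                    #        1e8 is 8), so c is lost before a is added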
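The same effect shows up at the systems level. Below is a small simulation, our own sketch rather than FlashInfer or NCCL code, that sums the same eight per-rank partial outputs in the per-chunk ring order from the NCCL listing above and in the fixed Lamport order, with FP16 accumulation so every add rounds.

```python
import numpy as np

# Eight tensor-parallel ranks, each holding a partial output for the same
# hidden vector (as after a row-parallel linear layer). Values are made up.
rng = np.random.default_rng(0)
num_ranks, hidden = 8, 4096
partials = rng.standard_normal((num_ranks, hidden)).astype(np.float16)

def lamport_order(parts):
    """Fixed order on every element: r0 + r1 + ... + r7."""
    acc = parts[0].copy()
    for r in range(1, len(parts)):
        acc = acc + parts[r]  # FP16 add, rounded after every step
    return acc

def ring_reduce_scatter_order(parts):
    """Per-chunk order from the NCCL listing: chunk i starts at its
    owner and accumulates around the ring (i, i-1, ..., i+1 mod 8)."""
    n, chunk = len(parts), parts.shape[1] // len(parts)
    out = np.empty_like(parts[0])
    for i in range(n):
        sl = slice(i * chunk, (i + 1) * chunk)
        acc = parts[i][sl].copy()
        for step in range(1, n):
            acc = acc + parts[(i - step) % n][sl]
        out[sl] = acc
    return out

a = lamport_order(partials)
b = ring_reduce_scatter_order(partials)
print(f"{int((a != b).sum())} of {hidden} elements differ bitwise")
```

A sizable fraction of elements come out bitwise different, even though both functions compute the same mathematical sum.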

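Finally, the amplification claim can be demonstrated with a toy top-k router. This is a hypothetical sketch with random weights; 64 experts, top-8, and the noise scales are arbitrary illustrative numbers, not either model's actual configuration. Perturbations on the order of accumulated rounding error flip the selected expert set for a small but nonzero fraction of tokens.

```python
import numpy as np

rng = np.random.default_rng(1)
hidden, num_experts, top_k, num_tokens = 1024, 64, 8, 5000

# Toy router: one linear layer producing expert logits per token.
w_router = (rng.standard_normal((hidden, num_experts))
            / np.sqrt(hidden)).astype(np.float32)
tokens = rng.standard_normal((num_tokens, hidden)).astype(np.float32)

def selected_experts(x):
    logits = x @ w_router                                  # (tokens, experts)
    top = np.argpartition(logits, -top_k, axis=1)[:, -top_k:]
    return np.sort(top, axis=1)  # sort ids so expert *sets* compare equal

baseline = selected_experts(tokens)
for eps in (1e-2, 1e-3, 1e-4):
    noise = (eps * rng.standard_normal(tokens.shape)).astype(np.float32)
    changed = (selected_experts(tokens + noise) != baseline).any(axis=1).mean()
    print(f"hidden-state noise std {eps:g}: "
          f"{changed:.2%} of tokens routed differently")
```

Any nonzero fraction is enough: a token that visits different experts produces a visibly different hidden state, which perturbs routing again at every subsequent layer.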