From SwiGLU Backward to INT8 Quantization: Notes from a KernelGen Challenge 9 Win
Official challenge page
Challenge 9: https://kernelgen.flagos.io/challenge/9?lang=zh&tab=readme
This note records our optimization process for KernelGen Challenge 9, whose function is named silu_dot_fwd_bwd_quant_fuse.
The offline final result was pleasantly dramatic: our submission for this task eventually placed first in the offline round. The part worth recording is the fairly plain sequence of constraints behind that result: understand the backward formula, respect BF16 quantization semantics, reduce repeated memory traffic, then split backend paths when one Triton kernel stopped being portable enough.
There is also a small warning hidden in the result. Some high-ranking attempts appeared to run into Wrong Answer cases later. Whether or not a score looks beautiful, this particular operator keeps asking the same question: did we preserve the byte-level quantization contract?
The Thread
The article first reconstructs the operator from SwiGLU backward and INT8 quantization. Then it explains why a 128x128 tile is natural for this problem, why the first large win came from fusing M-group and K-group quantization, and why the final implementation became a small collection of backend-specific Triton routes.
The Operator in Plain Terms
The function signature is:
```python
def silu_dot_fwd_bwd_quant_fuse(
    x,
    grad_y,
    grad_input_q,
    grad_input_s,
    y_q_t,
    y_s_t,
    group_size=128,
):
    ...
```
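The input x has shape [M, 2H] and BF16 dtype. It is naturally split into two halves, a gate half $g$ and an up half $u$:
\[x=[g,u],\qquad g=x_{:,\,0:H},\quad u=x_{:,\,H:2H},\qquad g,u\in\mathbb{R}^{M\times H}.\]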
This is the common SwiGLU-style feed-forward block. During backward, instead of reading a saved activation y, the operator recomputes:
\[y=\operatorname{silu}(g)\odot u=g\,\sigma(g)\odot u,\qquad \sigma(g)=\frac{1}{1+e^{-g}}.\]
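Given the upstream gradient grad_y of shape [M, H], the gradients are:
\[d_g=\operatorname{grad}_y\odot u\odot\sigma(g)\bigl(1+g(1-\sigma(g))\bigr),\qquad d_u=\operatorname{grad}_y\odot\operatorname{silu}(g).\]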
The output gradient is:
\[\operatorname{grad\_input}=[d_g,d_u]\in\mathbb{R}^{M\times 2H}.\]
The task does not simply return grad_input and y. It quantizes both for downstream INT8 GEMMs:
| Output | Shape | Meaning |
|---|---|---|
| grad_input_q | [M, 2H] INT8 | row-wise, per-128-channel quantized gradient |
| grad_input_s | [M, 2H/128] FP32 | scales for grad_input_q |
| y_q_t | [H, M] INT8 | transposed, per-128-token-group quantized y |
| y_s_t | [H, M/128] FP32 | scales for y_q_t |
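Before any quantization enters the picture, the unquantized recompute and backward for x = [g, u] can be sketched in NumPy. This is a minimal sketch assuming FP32 math throughout; it deliberately skips the BF16 rounding that the real reference applies, and `swiglu_bwd_reference` is a hypothetical helper name, not part of the challenge API.

```python
import numpy as np

def sigmoid(v):
    return 1.0 / (1.0 + np.exp(-v))

def swiglu_bwd_reference(x, grad_y):
    """Unquantized recompute for x = [g, u]; returns (grad_input [M, 2H], y [M, H])."""
    H = x.shape[1] // 2
    g, u = x[:, :H], x[:, H:]
    sig = sigmoid(g)
    silu = g * sig
    y = silu * u                                       # recomputed forward output
    d_g = grad_y * u * sig * (1.0 + g * (1.0 - sig))   # textbook SiLU-gate derivative
    d_u = grad_y * silu
    return np.concatenate([d_g, d_u], axis=1), y

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 16))
grad_y = rng.standard_normal((4, 8))
grad_input, y = swiglu_bwd_reference(x, grad_y)
print(grad_input.shape, y.shape)  # (4, 16) (4, 8)
```

A central finite difference on $\sum \operatorname{grad}_y \odot y$ is a cheap way to confirm these gradients before layering quantization on top. In the real operator, grad_input would also be rounded to BF16 at this point.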
The benchmark shapes are fixed:
| num_experts | tokens_per_expert | H |
|---|---|---|
| 8 | 128 | 2560 |
| 8 | 256 | 2560 |
| 16 | 128 | 2560 |
| 16 | 256 | 2560 |
| 32 | 128 | 2560 |
| 32 | 256 | 2560 |
| 8 | 128 | 4096 |
| 8 | 256 | 4096 |
| 16 | 128 | 4096 |
| 16 | 256 | 4096 |
| 32 | 128 | 4096 |
| 32 | 256 | 4096 |
Here $M=\text{num\_experts}\times\text{tokens\_per\_expert}$.
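A small sketch that enumerates the benchmark configurations and applies the shape rules from the output table above; `output_shapes` and `CONFIGS` are hypothetical names. Every configuration keeps $M$ and $H$ divisible by 128, so no quantization group is ragged.

```python
def output_shapes(num_experts, tokens_per_expert, H, group_size=128):
    """Shapes implied by the output table for one benchmark configuration."""
    M = num_experts * tokens_per_expert
    # All benchmark configs divide evenly, so every group holds exactly 128 values.
    assert M % group_size == 0 and H % group_size == 0
    return {
        "grad_input_q": (M, 2 * H),
        "grad_input_s": (M, 2 * H // group_size),
        "y_q_t": (H, M),
        "y_s_t": (H, M // group_size),
    }

# The twelve (num_experts, tokens_per_expert, H) rows of the benchmark table.
CONFIGS = [(e, t, h) for h in (2560, 4096) for e in (8, 16, 32) for t in (128, 256)]

print(sorted({e * t for e, t, _ in CONFIGS}))       # [1024, 2048, 4096, 8192]
print(output_shapes(32, 256, 4096)["grad_input_s"])  # (8192, 64)
```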
The Quantization Contract
For a vector group $z$ of length 128, the reference quantizer is essentially:
\[s=\max\left(\frac{\max_i |z_i|}{127},10^{-10}\right), \qquad q_i=\operatorname{int8}\left(\operatorname{clip}\left(\frac{z_i}{s},-127,127\right)\right).\]
For grad_input, the groups are channel chunks inside each row. For y_q_t, the groups are token chunks after transposition. This distinction is small in notation but important in memory layout:
grad_input: [M, 2H] -> quantize every row over channel groups of 128
y: [M, H] -> transpose to [H, M], quantize every channel over token groups of 128
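The contract and the two layouts can be sketched in NumPy as follows. This is an illustration under stated assumptions, not the official reference: `quantize_groups` and `quantize_outputs` are hypothetical names, and the `int8` cast is assumed to truncate toward zero the way `astype(np.int8)` does. If the reference rounds first, insert `np.rint` before the cast.

```python
import numpy as np

def quantize_groups(z, group_size=128):
    """Quantize the last axis of z in contiguous groups of `group_size`.

    s = max(absmax / 127, 1e-10); q = int8(clip(z / s, -127, 127)).
    Assumption: the int8 cast truncates toward zero, as astype(np.int8) does.
    """
    z = np.asarray(z, dtype=np.float32)
    groups = z.reshape(*z.shape[:-1], -1, group_size)
    s = np.maximum(np.abs(groups).max(axis=-1, keepdims=True) / 127.0, 1e-10)
    q = np.clip(groups / s, -127, 127).astype(np.int8)
    return q.reshape(z.shape), s[..., 0].astype(np.float32)

def quantize_outputs(grad_input, y, group_size=128):
    # grad_input [M, 2H]: each group is 128 consecutive channels of one row.
    grad_input_q, grad_input_s = quantize_groups(grad_input, group_size)
    # y [M, H] -> y.T [H, M]: each group is 128 consecutive tokens of one channel.
    y_q_t, y_s_t = quantize_groups(np.ascontiguousarray(np.asarray(y).T), group_size)
    return grad_input_q, grad_input_s, y_q_t, y_s_t
```

The only difference between the two outputs is which axis is made contiguous before grouping; that transpose is exactly what a fused kernel has to absorb.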
The correctness check has two layers:
| Check | Tolerance | Consequence |
|---|---|---|
| FP32 scales | atol=1e-4, rtol=1e-5 | scale semantics cannot drift much |
| dequantized INT8 values | atol=0.25, rtol=0.25 | the downstream GEMM-facing values must match |
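A sketch of the two-layer check, assuming each layer behaves like `np.allclose(actual, expected, ...)` with the tolerances from the table. The actual harness may orient or implement the comparison differently, and `dequantize` and `check_quantized` are hypothetical names.

```python
import numpy as np

def dequantize(q, s, group_size=128):
    """Multiply each 128-wide INT8 group by its FP32 scale."""
    groups = q.astype(np.float32).reshape(*q.shape[:-1], -1, group_size)
    return (groups * s[..., None]).reshape(q.shape)

def check_quantized(q, s, q_ref, s_ref, group_size=128):
    """Hypothetical two-layer check: scales first, then dequantized values."""
    if not np.allclose(s, s_ref, atol=1e-4, rtol=1e-5):
        return False
    return np.allclose(dequantize(q, s, group_size), dequantize(q_ref, s_ref, group_size),
                       atol=0.25, rtol=0.25)
```

Because the value layer compares dequantized products, an off-by-one INT8 code always passes when the scale is at most 0.25, while a scale drift beyond roughly 1e-4 fails the first layer immediately.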
The small trap is BF16. The reference computes grad_input as BF16 before quantization and also quantizes y.to(bfloat16).to(float32). Skipping that BF16 roundtrip is tempting, and it is often faster, but it changes scales and quantized values. Several failed attempts came from underestimating this detail.
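NumPy has no native bfloat16, so the sketch below emulates the `y.to(bfloat16).to(float32)` roundtrip by rounding the top 16 bits of each FP32 value to nearest-even (NaN handling omitted). `bf16_round` is a hypothetical helper, and the example group is contrived: a single outlier of 100.3 sets the absmax.

```python
import numpy as np

def bf16_round(x):
    """Emulate x.to(bfloat16).to(float32) via round-to-nearest-even on the top 16 bits."""
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    lsb = (bits >> np.uint32(16)) & np.uint32(1)  # breaks ties toward an even BF16 mantissa
    return ((bits + np.uint32(0x7FFF) + lsb) & np.uint32(0xFFFF0000)).view(np.float32)

group = np.full(128, 0.5, dtype=np.float32)
group[0] = 100.3                                  # contrived outlier that sets the absmax
s_skip = np.abs(group).max() / 127.0              # skips the BF16 roundtrip
s_ref = np.abs(bf16_round(group)).max() / 127.0   # mirrors the reference path
print(s_skip, s_ref)  # about 0.78976 and 0.79134
```

In BF16, 100.3 rounds to 100.5, so the two scales differ by about 1.6e-3. That is far outside the atol=1e-4, rtol=1e-5 scale tolerance, so in this example skipping the roundtrip fails the scale check outright.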
A Useful Algebraic Rewrite
The derivative can be written in the form used by the optimized kernels:
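\[d_g=\operatorname{grad}_y\odot (u\sigma)\odot (1+g-\operatorname{silu}(g)).\]
This follows from $\sigma(g)\bigl(1+g(1-\sigma(g))\bigr)=\sigma(g)\bigl(1+g-g\sigma(g)\bigr)$ and $g\sigma(g)=\operatorname{silu}(g)$. The form saves one multiplication in the hot path once silu_val = g * sigmoid(g) is already available, which it is, because y needs it anyway. It is a small example of the kind of optimization that only becomes visible after writing the math next to the kernel.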
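A quick numerical check that the rewrite matches the textbook derivative, written as a NumPy sketch with hypothetical function names. The comments count the multiplications that remain once `silu_val` exists, which is where the saved operation shows up.

```python
import numpy as np

def sigmoid(v):
    return 1.0 / (1.0 + np.exp(-v))

def d_gate_textbook(g, u, gy):
    sig = sigmoid(g)
    # four multiplies: gy*u, *sig, g*(1-sig), *(1+...)
    return gy * u * sig * (1.0 + g * (1.0 - sig))

def d_gate_rewritten(g, u, gy):
    sig = sigmoid(g)
    silu_val = g * sig  # in the kernel this is shared with y = silu(g) * u
    # three multiplies: u*sig, gy*(...), *(1+g-silu_val)
    return gy * (u * sig) * (1.0 + g - silu_val)

rng = np.random.default_rng(0)
g, u, gy = rng.standard_normal((3, 4, 128))
print(np.allclose(d_gate_textbook(g, u, gy), d_gate_rewritten(g, u, gy)))  # True
```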
Why 128x128 Is the Natural Tile
The group size is 128. That number appears twice:
M-group quantization: 128 channels per group
K-group quantization: 128 tokens per group
So a tile with 128 tokens and 128 channels completes both group reductions at once. In one 128x128 tile, a Triton program can produce all three quantized products:
| Tile product | Reduction axis | Scale count |
|---|---|---|
| d_gate | channels | 128 scales, one per token |
| d_up | channels | 128 scales, one per token |
| y | tokens | 128 scales, one per channel |
This gives a clean dataflow:
load gate/up/grad_y tile
-> sigmoid, silu, y, d_gate, d_up
-> row reductions for d_gate and d_up
-> column reduction for y
-> write three INT8 payloads and three scale vectors
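The dataflow can be mimicked in NumPy, with the Python loops standing in for the Triton launch grid. This is a readability sketch under assumptions: FP32 math, no BF16 rounding, group size fixed at 128, a token-block-outer loop order that need not match the submitted grid, and hypothetical names `quantize_rows`, `fused_tile`, and `fused_forward`.

```python
import numpy as np

T = 128  # tile edge == group_size

def quantize_rows(z):
    """Quantize each row of a [rows, 128] block as one group; returns (int8, per-row scale)."""
    s = np.maximum(np.abs(z).max(axis=1) / 127.0, 1e-10)
    return np.clip(z / s[:, None], -127, 127).astype(np.int8), s

def fused_tile(gate, up, grad_y):
    """One program's work on [128 tokens, 128 channels] tiles of gate, up, and grad_y."""
    sig = 1.0 / (1.0 + np.exp(-gate))
    silu = gate * sig
    y = silu * up
    d_gate = grad_y * (up * sig) * (1.0 + gate - silu)
    d_up = grad_y * silu
    # Row reductions over channels for d_gate/d_up; y.T turns the token reduction into rows.
    return quantize_rows(d_gate), quantize_rows(d_up), quantize_rows(y.T)

def fused_forward(x, grad_y):
    M, H = grad_y.shape
    giq = np.empty((M, 2 * H), np.int8)
    gis = np.empty((M, 2 * H // T), np.float32)
    y_q_t = np.empty((H, M), np.int8)
    y_s_t = np.empty((H, M // T), np.float32)
    for m in range(0, M, T):            # token block
        for c in range(0, H, T):        # channel block
            rows, cols = slice(m, m + T), slice(c, c + T)
            (qg, sg), (qu, su), (qy, sy) = fused_tile(
                x[rows, c:c + T], x[rows, H + c:H + c + T], grad_y[rows, cols])
            giq[rows, cols], gis[rows, c // T] = qg, sg
            giq[rows, H + c:H + c + T], gis[rows, (H + c) // T] = qu, su
            y_q_t[cols, rows], y_s_t[cols, m // T] = qy, sy
    return giq, gis, y_q_t, y_s_t
```

Each tile reads its three inputs once and writes all six outputs, which is the property the fused Triton kernel is after.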
The reference path materializes grad_input and y, then quantizes them through separate tensor operations. The optimized route tries to consume each tile while its values are still in registers.
The First Large Win: Fuse the Two Quantizers
The early baseline averaged around 3.26x. The first major jump came from locating a memory-traffic problem: K-group quantization was either rereading x and recomputing y, or paying for a temporary y surface. A 2D fused Triton kernel removed that repeated path and lifted the average speedup by about +4.61x, to roughly 7.87x.
The mechanism is concrete. The K-group quantizer needs y = silu(gate) * up; the fused tile keeps those values alive long enough to produce both:
M-group output: quantize [d_gate, d_up] by rows
K-group output: quantize y.T by token groups
A rough roofline estimate for the fused kernel gives:
| Quantity per 128x128 tile | Approximate value |
|---|---|
| BF16 reads | 3 * 128 * 128 * 2 bytes |
| INT8 writes | 3 * 128 * 128 bytes |
| scale writes | 3 * 128 * 4 bytes |
| operational intensity | about 2.2 FLOP/byte |
On an RTX 4090-like roofline, the ridge point is far higher than that. In plain language, this task is deeply memory-bound. Once that is clear, optimization becomes less like inventing clever arithmetic and more like avoiding unnecessary trips through memory.
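The table's byte counts and the ridge-point comparison are easy to reproduce. The RTX 4090 figures below (about 82.6 TFLOP/s FP32 and about 1008 GB/s) are assumed approximate public specs, not measurements from this work, and the last step assumes an ideal kernel that moves every byte exactly once.

```python
TILE = 128
bf16_reads = 3 * TILE * TILE * 2      # gate, up, grad_y tiles
int8_writes = 3 * TILE * TILE         # d_gate, d_up, y payloads
scale_writes = 3 * TILE * 4           # three FP32 scale vectors
tile_bytes = bf16_reads + int8_writes + scale_writes
print(tile_bytes)  # 148992

intensity = 2.2                       # FLOP/byte, the table's estimate
print(round(intensity * tile_bytes / (TILE * TILE)))  # 20 FLOPs per tile element implied

# Assumed approximate RTX 4090 specs: FP32 peak and DRAM bandwidth.
peak_flops, peak_bw = 82.6e12, 1008e9
ridge = peak_flops / peak_bw
print(round(ridge))  # 82 FLOP/byte

# Ideal traffic for the largest benchmark shape if every byte moves exactly once.
M, H = 32 * 256, 4096
min_bytes = (M * 2 * H * 2 + M * H * 2                 # read x and grad_y (BF16)
             + M * 2 * H + H * M                        # write grad_input_q and y_q_t (INT8)
             + (M * 2 * H // 128 + H * M // 128) * 4)   # write FP32 scales
print(f"{min_bytes / peak_bw * 1e3:.2f} ms")  # 0.30 ms lower bound
```

At 2.2 FLOP/byte against a ridge near 82, the arithmetic is almost free; the time budget is set by the bytes.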
A Small Line With a Large Effect
One of the surprisingly valuable changes was:
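```python
# less helpful: elementwise division by the scale
q = x / (absmax / 127.0)
# better: reciprocal of the scale, then elementwise multiply
q = x * (127.0 / absmax)
```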
This change came from treating division as a measured bottleneck, not as a cosmetic algebra rewrite. It raised the average speedup by about +0.92x. The lesson is modest but useful: tensor division inside Triton should not be assumed cheap, and the compiler may not always rewrite this expression in the way we want across all backends.
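NumPy cannot show the throughput effect, which is a property of GPU code generation, but it can show the numerical side of the rewrite. In this sketch (hypothetical names, the same $10^{-10}$ scale floor as the contract, truncating cast assumed) the two forms round at different points, so an INT8 code can differ by at most one step at a boundary.

```python
import numpy as np

FLOOR = 127.0 * 1e-10  # absmax floor equivalent to s >= 1e-10

def quantize_div(z):
    absmax = np.maximum(np.abs(z).max(axis=-1, keepdims=True), FLOOR)
    return np.clip(z / (absmax / 127.0), -127, 127).astype(np.int8)  # elementwise division

def quantize_mul(z):
    absmax = np.maximum(np.abs(z).max(axis=-1, keepdims=True), FLOOR)
    inv = 127.0 / absmax                                              # one division per group
    return np.clip(z * inv, -127, 127).astype(np.int8)               # elementwise multiply

rng = np.random.default_rng(0)
z = rng.standard_normal((64, 128)).astype(np.float32)
diff = np.abs(quantize_div(z).astype(np.int16) - quantize_mul(z).astype(np.int16))
print(diff.max())  # at most 1
```

Under the dequantized tolerance of 0.25, such a one-step difference is harmless whenever the scale is at most 0.25, which is why the rewrite could be adopted without breaking the contract.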
This puts the high-value changes in a simple cost-model frame:
| Change | Cost reduced |
|---|---|
| 2D fused tile | rereads and temporary tensors |
| channel-major grid | poorer L2 locality |
| division-to-multiply | slow elementwise division |
| y-first quantization | peak register lifetime |
| platform split | backend-specific codegen and cache behavior |
The Platform Split
A single beautiful Triton kernel was not the final shape. The platform matrix pushed us toward a more ordinary engineering answer: share the mathematical contract, split the backend paths.
The current implementation detects the backend and dispatches to separate routes:
| Platform route | Implementation idea |
|---|---|
| NVIDIA | dedicated 2D kernel, y-first quantization, num_warps=16, num_stages=1 |
| Hygon/AMD | 2D grid, BF16 quantization, no NVIDIA-style cache hints, num_warps=8 |
| TianShu | 2D grid, early BF16 y, deeper staging, num_stages=3 |
| MetaX | 2D grid with FP32 absmax/BF16-round scale path |
| T-Head | conservative generic fused route, tuned separately after platform recovery |
| MooreThreads | M-group kernel plus y buffer plus K-group buffer quantization |
| Ascend | Triton forward/backward recompute, then PyTorch quantization for correctness |
This table is less elegant than a single abstraction, but it matches the competition reality. The backends differ in warp size, BF16 behavior, cache policy, register pressure, and Triton implementation maturity. A win on one platform could easily be a regression or a Wrong Answer on another.
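The dispatch shape can be sketched as a lookup from backend name to route. Everything here is hypothetical: the backend strings, route names, and the fallback for unknown backends are illustrative. Only the `num_warps`/`num_stages` values come from the platform table above, and parameters the table does not state are left as `None`. How the backend is detected is vendor-specific and is not shown.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass(frozen=True)
class Route:
    kernel: str                        # hypothetical route name
    num_warps: Optional[int] = None    # None: not stated in the platform table
    num_stages: Optional[int] = None

ROUTES = {
    "nvidia": Route("fused_2d_y_first", num_warps=16, num_stages=1),
    "hygon_amd": Route("fused_2d_bf16_no_cache_hints", num_warps=8),
    "tianshu": Route("fused_2d_early_bf16_y", num_stages=3),
    "metax": Route("fused_2d_fp32_absmax_bf16_scale"),
    "thead": Route("generic_fused"),
    "moorethreads": Route("mgroup_then_kgroup_buffer"),
    "ascend": Route("triton_recompute_torch_quant"),
}

def select_route(backend: str) -> Route:
    # Assumption: unknown backends fall back to the conservative generic route.
    return ROUTES.get(backend, Route("generic_fused"))

print(select_route("nvidia"))  # Route(kernel='fused_2d_y_first', num_warps=16, num_stages=1)
```

Keeping the routes as data makes the split explicit: the quantization contract is shared, while launch policy is per backend.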
Final Dispatch Matters
Several experimental kernels can appear during development, but they are not necessarily used in the final dispatch. When reading optimization code, the important object is the branch that the submitted function actually launches.
MooreThreads Was Its Own Puzzle
MooreThreads was a good reminder that coalescing alone is not a sufficient explanation of performance. Several attempts tried to make the K-group path look more like the 2D fused CUDA path, but extra buffers or larger fused tiles often lost to bandwidth and register pressure.
The route that survived in the current implementation is more pragmatic:
M-group kernel:
recompute forward/backward
quantize grad_input
write BF16 y buffer
K-group buffer kernel:
read 128x128 BF16 y tiles
reduce over tokens
write y_q_t and y_s_t
This is not the purest fusion story, but it preserves enough locality while keeping each kernel’s register pressure manageable. The final tuning also moved simple kernels toward num_stages=1, because there is no deep inner loop that benefits much from software pipelining.
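A NumPy sketch of the two-kernel MooreThreads route, assuming FP32 math with an emulated BF16 store for both grad_input and the y buffer (round-to-nearest-even, NaN handling omitted). `bf16_round`, `quantize_rows`, `mgroup_kernel`, and `kgroup_buffer_kernel` are hypothetical names, and stage 1 is vectorized rather than tiled for brevity.

```python
import numpy as np

T = 128  # group size and tile edge

def bf16_round(x):
    """Emulate a BF16 store and FP32 reload (round-to-nearest-even; NaN handling omitted)."""
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    lsb = (bits >> np.uint32(16)) & np.uint32(1)
    return ((bits + np.uint32(0x7FFF) + lsb) & np.uint32(0xFFFF0000)).view(np.float32)

def quantize_rows(z):
    """One 128-wide group per row: returns (int8 codes, FP32 per-row scale)."""
    s = np.maximum(np.abs(z).max(axis=1) / 127.0, 1e-10)
    return np.clip(z / s[:, None], -127, 127).astype(np.int8), s

def mgroup_kernel(x, grad_y):
    """Stage 1: recompute, quantize grad_input per 128 channels, and write y as BF16."""
    M, H = grad_y.shape
    g, u = x[:, :H], x[:, H:]
    sig = 1.0 / (1.0 + np.exp(-g))
    silu = g * sig
    grad_input = bf16_round(np.concatenate(
        [grad_y * (u * sig) * (1.0 + g - silu), grad_y * silu], axis=1))
    q, s = quantize_rows(grad_input.reshape(-1, T))
    y_buf = bf16_round(silu * u)                 # the extra [M, H] BF16 surface
    return q.reshape(M, 2 * H), s.reshape(M, 2 * H // T), y_buf

def kgroup_buffer_kernel(y_buf):
    """Stage 2: read 128x128 tiles of the y buffer and quantize over tokens."""
    M, H = y_buf.shape
    y_q_t = np.empty((H, M), np.int8)
    y_s_t = np.empty((H, M // T), np.float32)
    for m in range(0, M, T):
        for c in range(0, H, T):
            q, s = quantize_rows(y_buf[m:m + T, c:c + T].T)  # rows are channels
            y_q_t[c:c + T, m:m + T] = q
            y_s_t[c:c + T, m // T] = s
    return y_q_t, y_s_t
```

The cost of the split is visible in the sketch: the [M, H] BF16 buffer is written once and read once more, in exchange for each stage keeping fewer values live at a time.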
Correctness Boundaries That Shaped the Work
Several failed attempts were useful because they drew the boundary of the problem:
| Attempt | Outcome |
|---|---|
| skip BF16 roundtrip | scale/quantization mismatches |
| use FP16 intermediate compute on AMD | numerical path drift |
| use PyTorch quantization on general GPU path | launch and framework overhead dominated |
| rely on autotune | T-Head timeout exceeded the platform limit |
| increase work per program too much | fewer blocks and more spills hurt occupancy |
| add temporary buffers broadly | extra memory traffic erased arithmetic savings |
This is also why the offline win felt satisfying. It was not only a matter of chasing the largest visible speedup. A high-speed implementation still had to survive the hidden shape/correctness surface, and the quantization contract made that surface fairly sharp.
Problem, Fix, and Payoff
The useful way to read the optimization history is as a sequence of bottlenecks we isolated:
| Bottleneck located | Fix | Payoff |
|---|---|---|
| PyTorch/reference surface materialized grad_input and y | Move recompute and quantization into Triton | From about 3.26x baseline into the kernel-optimized range |
| K-group quantization reread x or needed a temporary y surface | Fuse M-group and K-group quantization inside one 128x128 tile | About +4.61x, reaching roughly 7.87x |
| Tile order gave weak L2 locality | Use a channel-major grid | Around +0.39x in the measured route |
| Inner-loop quantization still used elementwise division | Rewrite x / (absmax / 127) as x * (127 / absmax) | About +0.92x |
| Backend differences caused performance and correctness regressions | Keep the mathematical contract shared, but split platform dispatch | Stable 7/7 submissions around 10.46x; offline final first place |
The exact numbers come from submission records and our review notes. Their role here is to mark which cost each change removed, so the narrative stays focused on technical decisions.
Lessons Worth Keeping
The main lesson is that this operator looks like activation math, but behaves like a quantization and memory-layout problem.
For future kernels, I would keep the following checklist:
| Question | Why it matters |
|---|---|
| What intermediate tensor is largest? | It is probably the first thing to avoid materializing |
| Are two reductions using the same group size? | This may reveal the natural tile |
| Does the reference round to BF16 before quantization? | It determines both scales and INT8 values |
| Is division still present in the inner loop? | It may be an avoidable throughput sink |
| Does a platform need its own route? | Multi-backend Triton is not one architecture |
| Did a change pass all hidden-like shapes? | A fast Wrong Answer is still a failed kernel |
If I compress the work into one sentence, it would be this: the winning route was to align the tile with the quantization groups, keep the data close while producing both quantized outputs, and stop pretending that all seven backends wanted the same kernel.
