From SwiGLU Backward to INT8 Quantization: Notes from a KernelGen Challenge 9 Win

This note records our optimization process for KernelGen Challenge 9, whose function is named silu_dot_fwd_bwd_quant_fuse.

The result was pleasantly dramatic: this task eventually placed first in the offline final. The part worth recording, though, is the fairly plain sequence of constraints behind that result: understand the backward formula, respect the BF16 quantization semantics, reduce repeated memory traffic, and split the backend paths once a single Triton kernel stopped being portable enough.

There is also a small warning hidden in the result. Some high-ranking attempts appeared to run into Wrong Answer cases later. Whether or not a score looks beautiful, this particular operator keeps asking the same question: did we preserve the byte-level quantization contract?

The Thread

The article first reconstructs the operator from SwiGLU backward and INT8 quantization. Then it explains why a 128x128 tile is natural for this problem, why the first large win came from fusing M-group and K-group quantization, and why the final implementation became a small collection of backend-specific Triton routes.

The Operator in Plain Terms

The function signature is:

def silu_dot_fwd_bwd_quant_fuse(
    x,
    grad_y,
    grad_input_q,
    grad_input_s,
    y_q_t,
    y_s_t,
    group_size=128,
):
    ...

The input x has shape [M, 2H] and BF16 dtype. It is naturally split into two halves:

\[x = [g, u], \qquad g,u\in\mathbb{R}^{M\times H}.\]

This is the common SwiGLU-style feed-forward block. During backward, instead of reading a saved activation y, the operator recomputes:

\[\sigma = \operatorname{sigmoid}(g),\qquad \operatorname{silu}(g)=g\sigma,\qquad y=\operatorname{silu}(g)\odot u.\]

Given upstream gradient grad_y, the gradients are:

\[d_u = \operatorname{grad}_y\odot \operatorname{silu}(g),\]

\[d_g =\operatorname{grad}_y\odot u\odot\sigma\odot \big(1+g(1-\sigma)\big).\]

The output gradient is:

\[\operatorname{grad\_input}=[d_g,d_u]\in\mathbb{R}^{M\times 2H}.\]
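The two gradient formulas above can be sanity-checked on a single element with central finite differences. This is a minimal scalar sketch in plain Python, not the kernel code; the helper names (`swiglu_forward`, `swiglu_backward`, `numeric_grad`) are hypothetical and chosen only for illustration.

```python
import math

def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))

def swiglu_forward(g, u):
    """y = silu(g) * u for a single element."""
    return g * sigmoid(g) * u

def swiglu_backward(g, u, grad_y):
    """Closed-form (d_g, d_u) following the formulas above."""
    s = sigmoid(g)
    d_u = grad_y * g * s
    d_g = grad_y * u * s * (1.0 + g * (1.0 - s))
    return d_g, d_u

def numeric_grad(g, u, grad_y, eps=1e-6):
    """Central finite differences, scaled by the upstream gradient."""
    d_g = grad_y * (swiglu_forward(g + eps, u) - swiglu_forward(g - eps, u)) / (2 * eps)
    d_u = grad_y * (swiglu_forward(g, u + eps) - swiglu_forward(g, u - eps)) / (2 * eps)
    return d_g, d_u

analytic = swiglu_backward(0.7, -1.3, 0.5)
numeric = numeric_grad(0.7, -1.3, 0.5)
```

The analytic and numeric pairs should agree to well below the finite-difference step size.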

The task does not simply return grad_input and y. It quantizes both for downstream INT8 GEMMs:

| Output | Shape and dtype | Meaning |
| --- | --- | --- |
| grad_input_q | [M, 2H] INT8 | row-wise, per-128-channel quantized gradient |
| grad_input_s | [M, 2H/128] FP32 | scales for grad_input_q |
| y_q_t | [H, M] INT8 | transposed, per-128-token-group quantized y |
| y_s_t | [H, M/128] FP32 | scales for y_q_t |

The benchmark shapes are fixed:

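| num_experts | tokens_per_expert | H |
| --- | --- | --- |
| 8 | 128 | 2560 |
| 8 | 256 | 2560 |
| 16 | 128 | 2560 |
| 16 | 256 | 2560 |
| 32 | 128 | 2560 |
| 32 | 256 | 2560 |
| 8 | 128 | 4096 |
| 8 | 256 | 4096 |
| 16 | 128 | 4096 |
| 16 | 256 | 4096 |
| 32 | 128 | 4096 |
| 32 | 256 | 4096 |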

Here $M=\text{num\_experts}\times\text{tokens\_per\_expert}$.
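To make the sizes concrete, the sketch below enumerates the twelve benchmark shapes and derives the output shapes from the table of outputs. The `tiles` entry (the number of 128x128 tiles covering an [M, H] surface) is a quantity derived here for illustration, not something the task statement reports.

```python
from itertools import product

GROUP = 128  # quantization group size from the function signature

shapes = []
# Same order as the benchmark list: H outermost, tokens_per_expert innermost.
for H, num_experts, tokens_per_expert in product((2560, 4096), (8, 16, 32), (128, 256)):
    M = num_experts * tokens_per_expert
    shapes.append({
        "M": M,
        "H": H,
        "grad_input_q": (M, 2 * H),
        "grad_input_s": (M, 2 * H // GROUP),
        "y_q_t": (H, M),
        "y_s_t": (H, M // GROUP),
        "tiles": (M // GROUP) * (H // GROUP),
    })

for s in shapes:
    print(s["M"], s["H"], s["grad_input_s"], s["y_s_t"], s["tiles"])
```

Every M and H in the list is a multiple of 128, so no shape in this benchmark needs a partial tile.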

The Quantization Contract

For a vector group $z$ of length 128, the reference quantizer is essentially:

\[s=\max\left(\frac{\max_i |z_i|}{127},10^{-10}\right), \qquad q_i=\operatorname{int8}\left(\operatorname{clip}\left(\frac{z_i}{s},-127,127\right)\right).\]
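A minimal plain-Python sketch of this quantizer for one group follows. The rounding mode is an assumption: the formula above only shows a clip and an int8 cast, and this sketch uses Python's `round` (round-half-to-even), which may differ from what the reference does. The function name `quantize_group` is hypothetical.

```python
def quantize_group(z, eps=1e-10):
    """Quantize one group of values to INT8 range with a single FP scale.

    Assumption: round-to-nearest before the cast; the reference may truncate.
    """
    absmax = max(abs(v) for v in z)
    scale = max(absmax / 127.0, eps)  # the 1e-10 floor guards all-zero groups
    q = [int(max(-127, min(127, round(v / scale)))) for v in z]
    return q, scale

# A 128-element group spanning [-1, 63/64].
group = [(i - 64) / 64.0 for i in range(128)]
q, scale = quantize_group(group)

# An all-zero group falls back to the floor scale and quantizes to zeros.
q_zero, scale_zero = quantize_group([0.0] * 128)
```

With round-to-nearest, the dequantized value `q[i] * scale` stays within half a scale step of the original element.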

For grad_input, the groups are channel chunks inside each row. For y_q_t, the groups are token chunks after transposition. This distinction is small in notation but important in memory layout:

grad_input: [M, 2H]  -> quantize every row over channel groups of 128
y:          [M, H]   -> transpose to [H, M], quantize every channel over token groups of 128
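The layout difference is easiest to see as index arithmetic. The two helpers below are hypothetical and only illustrate which scale entry each element shares, assuming the output shapes listed earlier.

```python
GROUP = 128

def grad_input_scale_index(m, c, group=GROUP):
    """Element grad_input[m, c] uses scale grad_input_s[m, c // group]."""
    return (m, c // group)

def y_scale_index(m, h, group=GROUP):
    """Element y[m, h] is stored at y_q_t[h, m] and uses y_s_t[h, m // group]."""
    return (h, m // group)

# Token 5, channel 300 of grad_input: third channel group of row 5.
print(grad_input_scale_index(5, 300))
# Token 300, channel 7 of y: third token group of channel 7.
print(y_scale_index(300, 7))
```

For grad_input, neighbours along a row share a scale; for y, neighbours along a column (consecutive tokens) share one, which is why the second quantizer reduces across rows of the original layout.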

The correctness check has two layers:

| Check | Tolerance | Consequence |
| --- | --- | --- |
| FP32 scales | atol=1e-4, rtol=1e-5 | scale semantics cannot drift much |
| dequantized INT8 values | atol=0.25, rtol=0.25 | the downstream GEMM-facing values must match |
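The sketch below shows how the two layers might interact for a single group. It assumes the tolerances combine in the usual allclose form, |actual - expected| <= atol + rtol * |expected|; the actual checker was not inspected here, and the helper names are hypothetical.

```python
def within_tolerance(actual, expected, atol, rtol):
    # Assumed allclose-style combination of absolute and relative tolerance.
    return abs(actual - expected) <= atol + rtol * abs(expected)

def check_group(q, s, q_ref, s_ref):
    """Return (scale_ok, values_ok) for one quantization group."""
    scale_ok = within_tolerance(s, s_ref, atol=1e-4, rtol=1e-5)
    values_ok = all(
        within_tolerance(qi * s, ri * s_ref, atol=0.25, rtol=0.25)
        for qi, ri in zip(q, q_ref)
    )
    return scale_ok, values_ok

# Same INT8 payload, but the scale is off by 0.0015.
result = check_group([127, -64, 3], 0.7915, [127, -64, 3], 0.79)
print(result)
```

Under this assumption the dequantized values still pass while the scale check fails, which matches the observation that the scale tolerance is the sharper of the two layers.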

The small trap is BF16. The reference computes grad_input as BF16 before quantization and also quantizes y.to(bfloat16).to(float32). Skipping that BF16 roundtrip is tempting, and it is often faster, but it changes scales and quantized values. Several failed attempts came from underestimating this detail.
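The effect of the roundtrip on a scale can be reproduced without a GPU. The sketch below emulates BF16 rounding (round-to-nearest-even on the top 16 bits of an FP32 value) in plain Python; `to_bf16` is a hypothetical helper for finite, non-overflowing inputs, and the example absmax is made up.

```python
import struct

def to_bf16(x):
    """Emulate float32 -> bfloat16 -> float32 for finite inputs (ties to even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Add a rounding bias so that truncating the low 16 bits rounds to nearest even.
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]

absmax_fp32 = 100.3  # hypothetical group absmax before BF16 rounding
scale_without_roundtrip = absmax_fp32 / 127.0
scale_with_roundtrip = to_bf16(absmax_fp32) / 127.0
scale_gap = abs(scale_with_roundtrip - scale_without_roundtrip)

print(to_bf16(absmax_fp32), scale_gap)
```

BF16 values in [64, 128) are spaced 0.5 apart, so the rounded absmax lands on 100.5 and the scale moves by roughly 1.6e-3, more than an order of magnitude beyond an atol of 1e-4.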

A Useful Algebraic Rewrite

The derivative can be written in the form used by the optimized kernels:

\[d_g=\operatorname{grad}_y\odot (u\sigma)\odot (1+g-\operatorname{silu}(g)).\]

This saves one multiplication in the hot path once silu_val = g * sigmoid(g) is already available. It is a small example of the kind of optimization that only becomes visible after writing the math next to the kernel.
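A quick numerical check of the rewrite, as a scalar sketch in plain Python: both functions compute d_g, the second one reusing `silu_val` as described above. The function names and the sample values for u and grad_y are arbitrary choices for illustration.

```python
import math

def d_gate_original(g, u, grad_y):
    sig = 1.0 / (1.0 + math.exp(-g))
    return grad_y * u * sig * (1.0 + g * (1.0 - sig))

def d_gate_rewritten(g, u, grad_y):
    sig = 1.0 / (1.0 + math.exp(-g))
    silu_val = g * sig  # already needed for y and d_u
    return grad_y * (u * sig) * (1.0 + g - silu_val)

samples = [x / 4.0 for x in range(-32, 33)]  # g in [-8, 8]
max_gap = max(
    abs(d_gate_original(g, 1.7, -0.9) - d_gate_rewritten(g, 1.7, -0.9))
    for g in samples
)
print(max_gap)
```

The two forms agree up to floating-point rounding, since 1 + g(1 - sigma) and 1 + g - silu(g) are the same expression with g * sigma factored differently.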

Why 128x128 Is the Natural Tile

The group size is 128. That number appears twice:

M-group quantization: 128 channels per group
K-group quantization: 128 tokens per group

So a tile with 128 tokens and 128 channels closes both loops at once. Within one 128x128 tile, a Triton program can compute every reduction it needs:

| Tile product | Reduction axis | Scale count |
| --- | --- | --- |
| d_gate | channels | 128 scales, one per token |
| d_up | channels | 128 scales, one per token |
| y | tokens | 128 scales, one per channel |

This gives a clean dataflow:

load gate/up/grad_y tile
    -> sigmoid, silu, y, d_gate, d_up
    -> row reductions for d_gate and d_up
    -> column reduction for y
    -> write three INT8 payloads and three scale vectors

The reference path materializes grad_input and y, then quantizes them through separate tensor operations. The optimized route tries to consume each tile while it is still close to the registers.
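The dataflow can be mimicked on a tiny tile in plain Python. This is a sketch of the idea only, not the Triton kernel: it uses nested lists, a 4x3 tile instead of 128x128, an emulated BF16 roundtrip, and round-to-nearest quantization, all of which are assumptions made for illustration. Where exactly the real kernels round to BF16 is not shown in this note.

```python
import math
import struct

def to_bf16(x):
    """Emulated BF16 roundtrip (ties to even) for finite floats."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]

def group_scale(values, eps=1e-10):
    return max(max(abs(v) for v in values) / 127.0, eps)

def quantize(values, scale):
    inv = 1.0 / scale  # one reciprocal per group
    return [max(-127, min(127, round(v * inv))) for v in values]

def process_tile(gate, up, grad_y):
    """gate, up, grad_y: tokens x channels nested lists for one tile."""
    tokens, channels = len(gate), len(gate[0])
    y = [[0.0] * channels for _ in range(tokens)]
    d_gate = [[0.0] * channels for _ in range(tokens)]
    d_up = [[0.0] * channels for _ in range(tokens)]
    for t in range(tokens):
        for c in range(channels):
            g, u, gy = gate[t][c], up[t][c], grad_y[t][c]
            sig = 1.0 / (1.0 + math.exp(-g))
            silu = g * sig
            y[t][c] = to_bf16(silu * u)
            d_up[t][c] = to_bf16(gy * silu)
            d_gate[t][c] = to_bf16(gy * (u * sig) * (1.0 + g - silu))
    # Row reductions: one scale per token for each gradient half.
    s_gate = [group_scale(row) for row in d_gate]
    s_up = [group_scale(row) for row in d_up]
    # Column reduction: transpose y, then one scale per channel.
    y_t = [list(col) for col in zip(*y)]
    s_y = [group_scale(col) for col in y_t]
    return {
        "d_gate_q": [quantize(r, s) for r, s in zip(d_gate, s_gate)],
        "d_up_q": [quantize(r, s) for r, s in zip(d_up, s_up)],
        "y_q_t": [quantize(col, s) for col, s in zip(y_t, s_y)],
        "s_gate": s_gate, "s_up": s_up, "s_y": s_y,
    }

gate = [[0.5 + 0.1 * (t + c) for c in range(3)] for t in range(4)]
up = [[1.0 - 0.2 * (t - c) for c in range(3)] for t in range(4)]
grad_y = [[0.3 * (c + 1) for c in range(3)] for t in range(4)]
out = process_tile(gate, up, grad_y)
```

All three payloads and all three scale vectors come out of a single pass over the loaded values, which is the property the fused tile relies on.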

The First Large Win: Fuse the Two Quantizers

The early baseline sat at around 3.26x speedup. The first major jump came from locating a memory-traffic problem: K-group quantization was either rereading x and recomputing y, or paying for a temporary y surface. A 2D fused Triton kernel removed that repeated path and lifted the average speedup by about 4.61x, to roughly 7.87x.

The mechanism is concrete. The K-group quantizer needs y = silu(gate) * up; the fused tile keeps those values alive long enough to produce both:

M-group output: quantize [d_gate, d_up] by rows
K-group output: quantize y.T by token groups

A rough roofline estimate for the fused kernel:

| Quantity per 128x128 tile | Approximate value |
| --- | --- |
| BF16 reads | 3 * 128 * 128 * 2 bytes |
| INT8 writes | 3 * 128 * 128 bytes |
| scale writes | 3 * 128 * 4 bytes |
| operational intensity | about 2.2 FLOP/byte |

On an RTX 4090-like roofline, the ridge point is far higher than that. In plain language, this task is deeply memory-bound. Once that is clear, optimization becomes less like inventing clever arithmetic and more like avoiding unnecessary trips through memory.
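The arithmetic behind that estimate can be written out directly. The byte counts come from the table; the per-element FLOP count is back-derived from the 2.2 FLOP/byte figure rather than measured, and the RTX 4090 peak numbers are approximate public specifications used here as assumptions.

```python
TILE = 128

bf16_reads = 3 * TILE * TILE * 2   # gate, up, grad_y tiles in BF16
int8_writes = 3 * TILE * TILE      # d_gate, d_up, y payloads in INT8
scale_writes = 3 * TILE * 4        # three FP32 scale vectors
bytes_per_tile = bf16_reads + int8_writes + scale_writes

intensity = 2.2  # FLOP/byte, taken from the estimate above
implied_flops_per_element = intensity * bytes_per_tile / (TILE * TILE)

# Assumed approximate peaks for an RTX 4090 (FP32 throughput, DRAM bandwidth).
peak_flops = 82.6e12
peak_bandwidth = 1.008e12
ridge_point = peak_flops / peak_bandwidth

print(bytes_per_tile, round(implied_flops_per_element, 1), round(ridge_point, 1))
```

Under these assumptions each tile moves about 149 KB for roughly 20 FLOPs per element, and the ridge point sits near 80 FLOP/byte, well over an order of magnitude above the kernel's intensity.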

A Small Line With a Large Effect

One of the surprisingly valuable changes was:

# less helpful
q = x / (absmax / 127.0)

# better
q = x * (127.0 / absmax)

This change came from treating division as a measured bottleneck, not as a cosmetic algebra rewrite. It raised the average speedup by about +0.92x. The lesson is modest but useful: tensor division inside Triton should not be assumed cheap, and the compiler may not always rewrite this expression in the way we want across all backends.
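In scalar form, the rewrite replaces one division per element with one division per group followed by multiplications. The plain-Python sketch below compares the two forms on made-up data; it assumes a non-zero absmax and round-to-nearest, and the function names are hypothetical.

```python
import math

def quantize_div(z):
    absmax = max(abs(v) for v in z)
    # One division per element.
    return [round(v / (absmax / 127.0)) for v in z]

def quantize_mul(z):
    absmax = max(abs(v) for v in z)
    inv_scale = 127.0 / absmax  # one division per group
    return [round(v * inv_scale) for v in z]

group = [3.7 * math.sin(i) for i in range(128)]
q_div = quantize_div(group)
q_mul = quantize_mul(group)
max_step_gap = max(abs(a - b) for a, b in zip(q_div, q_mul))
print(max_step_gap)
```

The two expressions are not bit-identical in floating point, so a value sitting almost exactly on a rounding boundary can land one step apart. That is small next to the dequantized-value tolerance, but it is worth re-running the correctness check after such a rewrite rather than assuming equivalence.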

This puts the high-value changes in a simple cost-model frame:

| Change | Cost reduced |
| --- | --- |
| 2D fused tile | rereads and temporary tensors |
| channel-major grid | poorer L2 locality |
| division-to-multiply | slow elementwise division |
| y-first quantization | peak register lifetime |
| platform split | backend-specific codegen and cache behavior |

The Platform Split

A single beautiful Triton kernel was not the final shape. The platform matrix pushed us toward a more ordinary engineering answer: share the mathematical contract, split the backend paths.

The current implementation detects the backend and dispatches to separate routes:

| Platform route | Implementation idea |
| --- | --- |
| NVIDIA | dedicated 2D kernel, y-first quantization, num_warps=16, num_stages=1 |
| Hygon/AMD | 2D grid, BF16 quantization, no NVIDIA-style cache hints, num_warps=8 |
| TianShu | 2D grid, early BF16 y, deeper staging, num_stages=3 |
| MetaX | 2D grid with FP32 absmax/BF16-round scale path |
| T-Head | conservative generic fused route, tuned separately after platform recovery |
| MooreThreads | M-group kernel plus y buffer plus K-group buffer quantization |
| Ascend | Triton forward/backward recompute, then PyTorch quantization for correctness |
This table is less elegant than a single abstraction, but it matches the competition reality. The backends differ in warp size, BF16 behavior, cache policy, register pressure, and Triton implementation maturity. A win on one platform could easily be a regression or a Wrong Answer on another.
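One way to keep such a split readable is a plain lookup table from backend to launch configuration. The sketch below is hypothetical: the backend keys, the kernel labels, and the `Route` type are invented for illustration, and only the num_warps and num_stages values that appear in the table are filled in. How the real implementation detects the backend is not shown in this note.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass(frozen=True)
class Route:
    kernel: str                       # hypothetical label, not a real kernel name
    num_warps: Optional[int] = None   # None: not stated in the table
    num_stages: Optional[int] = None

ROUTES = {
    "nvidia": Route("fused_2d_y_first", num_warps=16, num_stages=1),
    "hygon_amd": Route("fused_2d_bf16_quant", num_warps=8),
    "tianshu": Route("fused_2d_early_bf16_y", num_stages=3),
    "metax": Route("fused_2d_fp32_absmax"),
    "thead": Route("fused_generic"),
    "moorethreads": Route("m_group_then_k_group_buffer"),
    "ascend": Route("triton_recompute_then_torch_quant"),
}

def select_route(backend: str) -> Route:
    """Return the route for a backend key, failing loudly on unknown keys."""
    try:
        return ROUTES[backend]
    except KeyError:
        raise ValueError(f"no route registered for backend {backend!r}") from None

print(select_route("nvidia"))
```

Failing loudly on an unknown backend, instead of silently falling back to one vendor's kernel, fits the observation that a win on one platform can be a Wrong Answer on another.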

Final Dispatch Matters

Several experimental kernels can appear during development, but they are not necessarily used in the final dispatch. When reading optimization code, the important object is the branch that the submitted function actually launches.

MooreThreads Was Its Own Puzzle

MooreThreads was a good reminder that coalescing alone is not a sufficient explanation of performance. Several attempts tried to make the K-group path look more like the 2D fused CUDA path, but extra buffers or larger fused tiles often lost to bandwidth and register pressure.

The route that survived in the current implementation is more pragmatic:

M-group kernel:
    recompute forward/backward
    quantize grad_input
    write BF16 y buffer

K-group buffer kernel:
    read 128x128 BF16 y tiles
    reduce over tokens
    write y_q_t and y_s_t

This is not the purest fusion story, but it preserves enough locality while keeping each kernel’s register pressure manageable. The final tuning also moved simple kernels toward num_stages=1, because there is no deep inner loop that benefits much from software pipelining.

Correctness Boundaries That Shaped the Work

Several failed attempts were useful because they drew the boundary of the problem:

| Attempt | Outcome |
| --- | --- |
| skip BF16 roundtrip | scale/quantization mismatches |
| use FP16 intermediate compute on AMD | numerical path drift |
| use PyTorch quantization on general GPU path | launch and framework overhead dominated |
| rely on autotune | T-Head timeout exceeded the platform limit |
| increase work per program too much | fewer blocks and more spills hurt occupancy |
| add temporary buffers broadly | extra memory traffic erased arithmetic savings |

This is also why the offline win felt satisfying. It was not only a matter of chasing the largest visible speedup. A high-speed implementation still had to survive the hidden shape/correctness surface, and the quantization contract made that surface fairly sharp.

Problem, Fix, and Payoff

The useful way to read the optimization history is as a sequence of bottlenecks we isolated:

| Bottleneck located | Fix | Payoff |
| --- | --- | --- |
| PyTorch/reference surface materialized grad_input and y | Move recompute and quantization into Triton | From about 3.26x baseline into the kernel-optimized range |
| K-group quantization reread x or needed a temporary y surface | Fuse M-group and K-group quantization inside one 128x128 tile | About +4.61x, reaching roughly 7.87x |
| Tile order gave weak L2 locality | Use a channel-major grid | Around +0.39x in the measured route |
| Inner-loop quantization still used elementwise division | Rewrite x / (absmax / 127) as x * (127 / absmax) | About +0.92x |
| Backend differences caused performance and correctness regressions | Keep the mathematical contract shared, but split platform dispatch | Stable 7/7 submissions around 10.46x; offline final first place |

The exact numbers come from submission records and our review notes. Their role here is to mark which cost each change removed, so the narrative stays focused on technical decisions.

Lessons Worth Keeping

The main lesson is that this operator looks like activation math, but behaves like a quantization and memory-layout problem.

For future kernels, I would keep the following checklist:

| Question | Why it matters |
| --- | --- |
| What intermediate tensor is largest? | It is probably the first thing to avoid materializing |
| Are two reductions using the same group size? | This may reveal the natural tile |
| Does the reference round to BF16 before quantization? | It determines both scales and INT8 values |
| Is division still present in the inner loop? | It may be an avoidable throughput sink |
| Does a platform need its own route? | Multi-backend Triton is not one architecture |
| Did a change pass all hidden-like shapes? | A fast Wrong Answer is still a failed kernel |

If I compress the work into one sentence, it would be this: the winning route was to align the tile with the quantization groups, keep the data close while producing both quantized outputs, and stop pretending that all seven backends wanted the same kernel.
