Add Kimi AttentionResiduals (AttnRes) kernel #1161
kirsten-1 wants to merge 4 commits into linkedin:main from Feature/add attn res kernel
Conversation
Implements Attention Residuals from Kimi/Moonshot AI (arxiv.org/abs/2603.15031). Replaces standard residual connections with softmax attention over depth blocks to solve the PreNorm dilution problem.

Features:
- Single fused Triton kernel (RMSNorm + dot + softmax + weighted sum)
- Supports N≤32 blocks with scores in registers
- Includes forward/backward kernels and benchmark script

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
```python
# ============================================================================
# Correctness tests
# ============================================================================


def quick_test():
    print("Running correctness test...")
    device = 'cuda'

    configs = [
        (4, 2, 64, 4096, torch.float16, "N=4, D=4096, fp16"),
        (8, 2, 64, 4096, torch.float16, "N=8, D=4096, fp16"),
        (4, 2, 64, 4096, torch.bfloat16, "N=4, D=4096, bf16"),
        (8, 2, 64, 8192, torch.float16, "N=8, D=8192, fp16"),
        (4, 2, 64, 4096, torch.float32, "N=4, D=4096, fp32"),
    ]

    for N, B, T, D, dtype, name in configs:
        V = torch.randn(N, B, T, D, device=device, dtype=dtype)
        w_query = torch.randn(D, device=device, dtype=dtype) * 0.02
        w_norm = torch.ones(D, device=device, dtype=dtype)

        ref = pytorch_attn_res(V, w_query, w_norm)
        ours = LigerAttnResFunction.apply(V, w_query, w_norm, 1e-6)

        diff = (ours.float() - ref.float()).abs().max().item()
        tol = 1e-2 if dtype != torch.float32 else 1e-5
        status = "PASS" if diff < tol else "FAIL"
        print(f"  {name}: diff={diff:.2e} [{status}]")

    print("Correctness test done!\n")


def backward_test():
    print("Running backward correctness test...")
    device = 'cuda'

    configs = [
        (4, 2, 64, 4096, torch.float16, "N=4, D=4096, fp16"),
        (8, 2, 64, 4096, torch.float16, "N=8, D=4096, fp16"),
        (4, 2, 64, 4096, torch.float32, "N=4, D=4096, fp32"),
    ]

    for N, B, T, D, dtype, name in configs:
        V = torch.randn(N, B, T, D, device=device, dtype=dtype)
        w_query = torch.randn(D, device=device, dtype=dtype) * 0.02
        w_norm = torch.ones(D, device=device, dtype=dtype)

        # Reference
        V_ref = V.clone().requires_grad_(True)
        wq_ref = w_query.clone().requires_grad_(True)
        wn_ref = w_norm.clone().requires_grad_(True)
        h_ref = pytorch_attn_res(V_ref, wq_ref, wn_ref)
        h_ref.sum().backward()

        # Ours
        V_ours = V.clone().requires_grad_(True)
        wq_ours = w_query.clone().requires_grad_(True)
        wn_ours = w_norm.clone().requires_grad_(True)
        h_ours = LigerAttnResFunction.apply(V_ours, wq_ours, wn_ours, 1e-6)
        h_ours.sum().backward()

        dv_diff = (V_ours.grad.float() - V_ref.grad.float()).abs().max().item()
        dwq_diff = (wq_ours.grad.float() - wq_ref.grad.float()).abs().max().item()
        dwn_diff = (wn_ours.grad.float() - wn_ref.grad.float()).abs().max().item()
        tol_v = 5e-2 if dtype != torch.float32 else 1e-3
        # dWq/dWn accumulate across all tokens via atomic_add; fp32 accum vs fp16 ref -> large diff expected
        tol_w = 1.0 if dtype != torch.float32 else 1e-3
        status = "PASS" if dv_diff < tol_v and dwq_diff < tol_w and dwn_diff < tol_w else "FAIL"
        print(f"  {name}: dV={dv_diff:.2e}, dWq={dwq_diff:.2e}, dWn={dwn_diff:.2e} [{status}]")

    print("Backward test done!\n")
```
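The loose tol_w above follows from the inline comment: the kernel accumulates dWq/dWn in fp32 via atomic_add, while the fp16 reference path accumulates in half precision, so the two results drift apart over many tokens. A small numpy illustration of that accumulation gap (sizes and seed are illustrative, not from the PR):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal(10_000).astype(np.float16)

# Half-precision running sum: every addition rounds to fp16
acc16 = np.float16(0.0)
for v in x:
    acc16 = np.float16(acc16 + v)

# fp32 accumulation of the same fp16 inputs (what fp32 atomics give you)
acc32 = x.astype(np.float32).sum()

drift = abs(float(acc16) - float(acc32))
print(f"fp16 accumulate: {float(acc16):.4f}, fp32 accumulate: {float(acc32):.4f}, drift: {drift:.4f}")
```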
Keep correctness tests in the pytest test suite (under test/transformers/) only, so we only have to maintain one test suite and avoid redundant tests in the benchmark CI.
Done! Moved to test/transformers/test_attn_res.py with pytest parametrize.
```python
# ============================================================================
# Performance tests
# ============================================================================
```
Follow the benchmark guideline we recently added.
https://github.com/linkedin/Liger-Kernel/blob/main/benchmark/BENCHMARK_GUIDELINES.md
If there's anything unclear or you find problematic, any feedback is welcome to refine our documents for future contributors!
Done! Rewrote following BENCHMARK_GUIDELINES.md.
src/liger_kernel/ops/attn_res.py
Outdated
```python
scores = tl.zeros((MAX_BLOCKS,), dtype=tl.float32) - 1e9  # -inf for unused
score_max = tl.full((), -1e9, dtype=tl.float32)
```

Suggested change:

```python
scores = tl.zeros((MAX_BLOCKS,), dtype=tl.float32) + float("-inf")  # -inf for unused
score_max = tl.full((), float("-inf"), dtype=tl.float32)
```
```python
# score = dot(w_query, k)
sc = tl.sum(w_query * k.to(tl.float32), axis=0)
scores = tl.where(tl.arange(0, MAX_BLOCKS) == i, sc, scores)
```
Just curious, is it a workaround for storing block score?
Yes! Triton lacks scalar indexing into register-held vectors. The tl.where pattern compiles to a predicated move, keeping everything in registers.
Added a comment explaining this.
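Outside Triton, the same predicated-select idea can be mimicked in numpy to show what each slot update does (MAX_BLOCKS and the score values here are illustrative):

```python
import numpy as np

MAX_BLOCKS = 32  # compile-time width; in Triton the whole vector stays in registers

# "Unused" slots start at -inf so a later softmax ignores them
scores = np.full(MAX_BLOCKS, -np.inf, dtype=np.float32)

# "Writing" slot i is a masked blend over the whole vector, not an indexed store
for i, sc in enumerate([0.5, 1.5, -0.25]):
    scores = np.where(np.arange(MAX_BLOCKS) == i, sc, scores)

# "Reading" slot i back is the dual trick: mask out everything else, then reduce
a_1 = np.where(np.arange(MAX_BLOCKS) == 1, scores, 0.0).sum()
```

Per the comment above, in the kernel both the blend and the masked reduction compile to predicated moves, so no scratch memory is needed for the score vector.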
src/liger_kernel/ops/attn_res.py
Outdated
```python
# Store alpha for backward
for i in tl.static_range(0, MAX_BLOCKS):
    if i < n_blocks:
        a_i = tl.sum(tl.where(tl.arange(0, MAX_BLOCKS) == i, alpha, 0.0))
        tl.store(Alpha_ptr + i * n_tokens + tok, a_i)
```
If we turn the alpha tensor layout into [B*T, N], we should be able to access alpha more efficiently. WDYT?
Great suggestion! Changed Alpha and RSTD from [N, BT] to [BT, N].
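The win from the [B*T, N] layout is visible in the flat offsets: one token's N alpha values go from stride n_tokens to stride 1, so a single vectorized load can fetch all of them. A numpy sketch of the addressing (shapes are illustrative):

```python
import numpy as np

N, BT = 4, 6
# The same logical alpha values in both layouts
a_old = np.arange(N * BT, dtype=np.float32).reshape(N, BT)  # [N, B*T]
a_new = np.ascontiguousarray(a_old.T)                       # [B*T, N]

tok = 2
# Old layout: token tok's N alphas sit at flat offsets i*BT + tok (stride BT)
offsets_old = [i * BT + tok for i in range(N)]
# New layout: they sit at tok*N + i (stride 1 -> one coalesced load)
offsets_new = [tok * N + i for i in range(N)]

vals_old = a_old.ravel()[offsets_old]
vals_new = a_new.ravel()[offsets_new]
```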
src/liger_kernel/ops/attn_res.py
Outdated
```python
a_i = tl.sum(tl.where(tl.arange(0, MAX_BLOCKS) == i, alpha, 0.0))
h += a_i * v

tl.store(Out_ptr + tok * D + cols, h.to(tl.load(V_ptr).dtype), mask=d_mask)
```
tl.store does the type conversion implicitly for you, or you can infer the dtype with pointer.dtype.element_ty instead.
```python
da_i = tl.sum(dh * v, axis=0)
d_alpha = tl.where(tl.arange(0, MAX_BLOCKS) == i, da_i, d_alpha)
alpha = tl.where(tl.arange(0, MAX_BLOCKS) == i, a_i, alpha)
```
We should be able to load the entire alpha vector with the layout change mentioned above.
Good catch! Removed manual dtype conversion.
- Use float("-inf") instead of -1e9 for softmax initialization
- Change Alpha/RSTD layout from [N, B*T] to [B*T, N] for coalesced access
- Remove redundant manual dtype conversion in tl.store (implicit conversion)
- Add comment explaining tl.where workaround for register-held score storage
- Rewrite benchmark script following BENCHMARK_GUIDELINES.md format
- Move correctness tests to pytest test suite (test/transformers/test_attn_res.py)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Summary
Implements Attention Residuals (AttnRes) from Kimi/Moonshot AI (arxiv.org/abs/2603.15031).
This PR addresses issue #1158.
What is AttnRes?
AttnRes replaces standard residual connections with softmax attention over depth blocks, addressing the PreNorm dilution problem where deep-layer contributions get diluted.
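The PR's pytorch_attn_res reference is not shown here, but a minimal PyTorch sketch consistent with the kernel description (RMSNorm each block to form keys, dot with a learned query, softmax over the N depth blocks, weighted sum) might look like the following; attn_res_reference and the exact normalization placement are assumptions, not the PR's code:

```python
import torch


def attn_res_reference(V, w_query, w_norm, eps=1e-6):
    """Hypothetical AttnRes reference in plain PyTorch.

    V:       [N, B, T, D] stacked outputs of N depth blocks
    w_query: [D] learned query vector
    w_norm:  [D] RMSNorm weight
    """
    Vf = V.float()
    # RMSNorm each block's features to form per-block keys
    rstd = torch.rsqrt(Vf.pow(2).mean(dim=-1, keepdim=True) + eps)
    k = Vf * rstd * w_norm.float()
    # One scalar score per (block, token): dot(w_query, k)
    scores = (k * w_query.float()).sum(dim=-1)   # [N, B, T]
    # Softmax over depth, then a weighted sum of the raw blocks
    alpha = torch.softmax(scores, dim=0)         # [N, B, T]
    h = (alpha.unsqueeze(-1) * Vf).sum(dim=0)    # [B, T, D]
    return h.to(V.dtype)
```

One handy property: if all N blocks are identical, every alpha is 1/N and the output reduces to that block, which makes the convex-combination behavior easy to sanity-check.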
Implementation
Files Added
Test Plan
Benchmark Results
Tested on: NVIDIA GeForce RTX 5090, CUDA 12.8
Forward Pass Performance
Forward + Backward Performance
Correctness Tests
All tests passed with expected numerical precision:
Key Insights
The dramatic speedup is achieved by:
Closes #1158