Triton 07: Flash Attention 3 — Triton으로 어디까지 가능한가
Triton 06: Flash Attention 2까지 와서 A100에서 PyTorch 대비 causal 22× 가속을 확인했다. 그 다음은 FA3다.
FlashAttention-3 논문은 Hopper의 비동기·저정밀 하드웨어를 활용해 H100에서 FA2 대비 추가 1.5–2.0× 가속을 만든다. 그런데 FA3의 핵심 기법은 대부분 Hopper 전용 + CUDA로만 표현 가능한 것들이다. 그럼 Triton으로는 어디까지 닿을 수 있을까?
FA3 논문 자체의 원리가 궁금하다면 FlashAttention-3 논문 리뷰를 먼저 읽는 것을 추천한다.
결론 미리: A100 + Triton에서 FA2 대비 추가 ~3–5% 가속이 한계다. Hopper 전용 기능(TMA, wgmma, 비동기 producer-consumer)은 Triton으로 표현이 어렵거나 불가능하기 때문이다. 공식 FA3의 진짜 가치는 H100 + CUDA에서만 살아난다는 사실을 실측으로 확인했다.
FA3의 7가지 핵심 기법 — Triton에서 가능한가?
FA3 논문이 말하는 가속의 출처를 정리하면 다음과 같다.
| # | 기법 | 효과 (논문) | Triton 표현 | 본 구현 적용 |
|---|---|---|---|---|
| 1 | Producer-consumer warp specialization | ~30% | ✗ (warp 추상화 너머) | 미적용 |
| 2 | Inter-warpgroup ping-pong (GEMM↔softmax) | ~15–20% | △ (tl.async_task preview) | 미적용 |
| 3 | TMA (비동기 메모리 복사) | H100에서 10–15% | △ (tl.make_tensor_descriptor preview) | 미적용 |
| 4 | wgmma (warpgroup matmul) | 자동 | ✓ (tl.dot이 H100에서 자동 사용) | 적용됨 |
| 5 | FP8 + incoherent processing | H100에서 1.5–2× | △ (block scaling 손수 작성) | 미적용 |
| 6 | Persistent kernel (NUM_SMS launch) | launch 절감 + occupancy | ✓ | 시도→실패 |
| 7 | Wider autotune | 1–5% | ✓ | 적용됨 |
Triton으로 즉시 가능한 것은 4·6·7 정도다. 1·2는 H100에서도 Triton으로 표현이 어렵고, 3은 experimental, 5는 매우 복잡하다. 즉 Triton FA3는 본질적으로 “FA2 + 스케줄링 튜닝” 의 영역에 머문다.
Method: Triton FA3에 적용한 변경
알고리즘 코어는 FA2와 동일하다 — un-scaled 누적, exp2, STAGE 1/2 분기, tl.dot accumulator. 그 위에 다음을 얹었다.
1. 확장된 autotune 탐색 공간
FA2는 6개 config를 탐색했다. FA3는 17개로 늘리고, BLOCK_M ≤ 256 + num_stages ≤ 6까지 포함한다.
문제는 head_dim=128 + BLOCK_M=256처럼 SRAM을 초과하는 조합이다. tl.arange가 power-of-2만 받는 제약과 별개로, 이런 config는 컴파일 시 OOM으로 실패한다.
2. SRAM-aware early pruning
prune_configs_by로 컴파일 전에 부적합한 config를 제거한다.
이 한 가지로 head_dim=128에서 BLOCK_M=256 케이스가 자동 차단되고, autotune이 실제로 fit하는 config만 프로파일링한다.
3. 명시적 fp32 누적
tl.dot은 기본적으로 fp32 accumulator를 쓰지만, out_dtype을 명시해 FA3 의도(저정밀 입력 + fp32 누적)를 분명히 했다.
H100의 wgmma 명령은 이 tl.dot이 자동으로 사용한다 — 별도 코드 없이 dot만 부르면 된다.
4. 1D grid + (bh, m) decomposition
FA2는 (num_m_tiles, bh) 2D grid였다. FA3는 1D grid (num_m_tiles · bh,)로 단순화한다.
같은 bh의 인접 Q 타일이 연속된 program ID를 받으므로 L2 cache 친화적인 SM 매핑이 자연스럽게 만들어진다. Triton의 SM 스케줄러는 인접 program을 같은 SM 또는 인접 SM으로 보내는 경향이 있어, K/V 재사용률이 올라간다.
Persistent kernel: 시도했으나 실패한 이야기
FA3 논문의 6번 기법 — persistent kernel을 시도해봤다. 결과부터: A100에서는 손해다.
기존 grid (num_m_tiles · bh,) 대신 (min(NUM_SMS, total_tiles),)로 launch하고, 각 program이 내부 루프로 자기 몫의 타일을 순회한다.
A100-SXM4-80GB, fp16, 4 GPU 평균 (num_heads=16, head_dim=64, causal=True):
| seq | FA3 (1D grid) | FA3-Persistent | 차이 |
|---|---|---|---|
| 4096 | 0.350 | 0.388 | -11% |
| 8192 | 0.995 | 1.164 | -17% |
| 16384 | 3.441 | 3.778 | -10% |
| 32768 | 12.911 | 13.595 | -5% |
긴 seq에서도 1D grid가 더 빠르다. persistent가 평균 5–17% 손해를 본다.
원인은 work imbalance 다.
A100 NUM_SMS = 108
seq=4096, bh=16, BLOCK_M=128 → num_tiles = 32 × 16 = 512 tiles
기존 grid: (512,) → 모든 SM이 work 보유, 하드웨어 스케줄러가 밸런싱
Persistent grid: (108,) → 각 SM이 ~5 tiles 직렬 처리
launch 절감 < 직렬 비율 증가로 인한 latency
Triton의 1D grid는 SM 스케줄러가 동적으로 분배하는 반면, persistent는 wave 단위 직렬화가 강제된다. seq가 길어 tile 수가 충분하면 그냥 1D grid가 더 좋다.
그럼 공식 FA3는 왜 persistent로 이득을 보는가? H100 환경에서 persistent가 이득이 되는 이유는 단순히 launch 절감이 아니다. Persistent 자체로는 위 실험처럼 손해다. 진짜 이득의 출처는 다음 셋이다.
- wgmma의 큰 타일 (BLOCK_M=192) → 같은 seq에서 tile 수가 줄어 직렬화 비율이 낮아짐
- Producer-consumer warp split + MBARRIER → 같은 SM 안에서 GEMM ↔ softmax 비동기 오버랩
- TMA 비동기 복사 → 메모리 latency를 compute로 가림
즉 persistent는 나머지 셋의 base 역할일 뿐, 단독으로는 의미가 없다. Triton에서는 이 셋 모두 표현이 어려우므로, persistent를 켜면 base만 남고 효과가 사라진다 — 오히려 손해.
이 이유로 본 구현에서는 flash_attention_v3_persistent를 제거했다. 향후 H100 + Triton TMA 지원이 더 발전한 뒤 재시도할 가치가 있다.
벤치마크 결과 (A100-SXM4-80GB × 4)
num_heads=16, fp16 · 4 GPU 평균 (표준편차 < 1%) · 11회 측정 중 첫 회 폐기.
Causal forward, head_dim=64
| seq | FA1 | FA2 | FA3 | PyTorch | FA3/FA2 | FA3/PT |
|---|---|---|---|---|---|---|
| 4096 | 0.571 | 0.361 | 0.350 | 5.243 | 1.03× | 14.97× |
| 8192 | 1.721 | 1.033 | 0.992 | 21.807 | 1.04× | 21.98× |
| 16384 | 5.972 | 3.556 | 3.391 | 70.856 | 1.05× | 20.90× |
| 32768 | 22.247 | 13.426 | 12.847 | OOM | 1.05× | — |
Causal forward, head_dim=128 (Llama/Qwen 표준)
| seq | FA1 | FA2 | FA3 | FA3/FA2 |
|---|---|---|---|---|
| 2048 | 0.374 | 0.257 | 0.245 | 1.05× |
| 4096 | 1.113 | 0.587 | 0.579 | 1.01× |
| 8192 | 3.905 | 1.824 | 1.763 | 1.03× |
| 32768 | 57.620 | 24.670 | 24.328 | 1.01× |
핵심 관찰
- causal + 긴 seq에서 일관되게 3–5% 추가 가속 — 확장 autotune이 더 큰 BLOCK_M·num_warps=8·num_stages=5를 선택해 SRAM 점유율이 좋아진 결과
- non-causal과 짧은 seq에서는 사실상 FA2와 동일 — 해당 케이스에서 best config가 같은 config로 수렴
- causal seq=8192에서 FA3/FA2 = 1.04×, FA3/PT = 21.98× 피크 — 알고리즘적 한계에 거의 닿았음을 의미
- fwd+bwd는 backward를 FA2 그대로 재사용했으므로 거의 동일 (FA3의 backward 개선은 H100 wgmma 와만 결합)
- GPU 0~3 측정값 표준편차는 평균의 1% 미만 — Triton autotune이 4개 프로세스에서 각각 동일한 best config로 수렴
솔직한 한계
- Hopper 전용 기능을 활용하지 않으므로 H100에서 측정해도 본 구현은 큰 이득이 없다. 진짜 FA3 가속은 TMA + wgmma + producer-consumer 가 결합될 때 나타난다.
- FP8 은 본 구현에 미포함. block scaling + Hadamard 변환은 Triton으로 표현 가능하지만 별도의 큰 작업이 필요하다.
- Backward 는 FA2와 동일한 3-stage 커널을 그대로 재사용. FA3 논문의 backward 개선도 wgmma 의존이 크다.
향후 Hopper 환경 확보 시 시도해보고 싶은 항목
본 구현은 A100에서만 검증되었다. H100 partition을 안정적으로 사용할 수 있게 되면 다음을 추가로 시도할 계획이다 — 결과는 별도 포스트로 정리하겠다.
- TMA descriptor (
tl.make_tensor_descriptor) — Hopper 전용 비동기 메모리 복사. K/V prefetch를 compute와 겹쳐 메모리 latency를 가리는 방식. H100에서만 동작하고 Triton 3.x에서 preview 단계. - FP8 + block scaling — FlashAttention-3 논문의 핵심 가속 출처 중 하나. E4M3 입력에 per-block scale factor를 곱하고 Hadamard 변환으로 outlier를 분산시킨다. H100 wgmma fp8 명령어가 필요하므로 A100에서는 의미 없음.
- Block size 192 + register tiling — Hopper의 더 큰 SRAM(228 KB)을 활용해 BLOCK_M을 키우는 시도. Triton의
tl.arangepower-of-2 제약은 register-tiling으로 우회 가능.
이번 포스트에서 시도하지 않은 이유는 단순하다 — A100에서는 이 셋이 의미가 없거나 동작하지 않는다. Hopper 환경에서만 측정 가능한 영역이므로 검증할 수 없는 코드를 미리 작성하지 않았다.
Production용 권장
| 환경 | 추천 |
|---|---|
| 학습/연구 | 본 Triton FA3 (충분히 빠르고 알고리즘 이해에 좋음) |
| Production A100 | 본 Triton FA3 또는 flash-attn 패키지 |
| Production H100 | flash-attn 패키지 (pip install flash-attn) — Triton으로 쫓아갈 가성비 X |
| 새 attention 변형 (sliding window 등) | CUTLASS 기반 fork 또는 Triton 프로토타이핑 후 CUDA 이식 |
시리즈 정리
Triton 05 (FA1) → Triton 06 (FA2) → 본 글까지의 누적 결과 (A100, causal, head_dim=64):
| Seq | FA1 (ms) | FA2 (ms) | FA3 (ms) | FA1→FA2 | FA2→FA3 | FA3 vs PT |
|---|---|---|---|---|---|---|
| 4096 | 0.571 | 0.361 | 0.350 | 1.58× | 1.03× | 14.97× |
| 8192 | 1.721 | 1.033 | 0.992 | 1.67× | 1.04× | 21.98× |
| 16384 | 5.972 | 3.556 | 3.391 | 1.68× | 1.05× | 20.90× |
| 32768 | 22.247 | 13.426 | 12.847 | 1.66× | 1.05× | — |
FA1 → FA2 의 ~1.6× 점프와 비교하면 FA2 → FA3 의 ~1.04× 는 작아 보이지만, 알고리즘 자체는 거의 한계에 도달했고 그 이상은 하드웨어 특화 기법으로 가야 한다는 신호다.
전체 코드
알고리즘 원리가 궁금하다면 FlashAttention-3 논문 리뷰를, FA1 Triton 구현이 궁금하다면 Triton 05를, FA2 개선이 궁금하다면 Triton 06을, Blackwell 최적화가 궁금하다면 FlashAttention-4 논문 리뷰를 참고하자.
참고 문헌
Enjoy Reading This Article?
Here are some more articles you might like to read next: