Triton 02: Fused Softmax — 커널 퓨전과 Reduction

개요

Softmax를 하나의 커널로 퓨전(fusion)하여 메모리 접근을 최소화합니다. 커널 퓨전이 왜 중요한지, reduction 연산을 어떻게 처리하는지 학습합니다.


핵심 개념

Softmax 수식

softmax(x_i) = exp(x_i - max(x)) / Σ exp(x_j - max(x))

max(x)를 빼는 이유: exp는 큰 값에서 오버플로우가 발생합니다. 최대값을 빼면 모든 지수가 0 이하가 되어 안정적으로 계산됩니다.

왜 커널 퓨전인가?

Reduction 연산

전체 데이터에서 하나의 값을 계산하는 연산:

  • max: 최대값
  • sum: 합계
  • mean: 평균

Triton에서는 tl.max(x, axis=0), tl.sum(x, axis=0) 으로 간단하게 수행합니다.


커널 동작 원리

입력 행렬의 각 행(row) 을 하나의 프로그램이 처리합니다.


코드 라인별 설명

커널 함수

핵심: max → exp → sum → 나누기를 전부 SRAM 안에서 처리. PyTorch는 이 4단계를 각각 별도 커널로 실행하므로 매번 Global Memory를 왕복합니다.

래퍼 함수


01 Vector Add와의 차이점

  01 Vector Add 02 Fused Softmax
처리 단위 1D 벡터의 청크 2D 행렬의 행
프로그램당 연산 덧셈 1번 max+exp+sum+나누기
퓨전 효과 없음 (연산이 1개) 4개 연산을 1커널로
새로운 기능 - tl.max, tl.sum, tl.exp, stride

벤치마크 결과

커널 퓨전 덕분에 메모리 대역폭을 절약하여, 특히 열(column) 수가 클수록 PyTorch 대비 성능 향상이 눈에 띕니다.


전체 코드




Enjoy Reading This Article?

Here are some more articles you might like to read next:

  • LLM 엔지니어가 알아야 할 GPU 아키텍처: Ampere → Hopper → Blackwell
  • FlashAttention-4: Algorithm and Kernel Pipelining Co-Design for Asymmetric Hardware Scaling
  • FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision
  • Triton 05: Flash Attention — 종합 프로젝트
  • Triton 04: Matrix Multiplication — 2D 타일링과 Autotune