Triton 03: RMSNorm — LLM에서 쓰이는 실전 커널

개요

LLaMA, Mistral, Gemma 등 최신 LLM에서 사용하는 RMSNorm을 Triton으로 구현합니다. Softmax와 유사한 패턴이지만, 학습 가능한 가중치(gamma)가 추가됩니다.


핵심 개념

LayerNorm vs RMSNorm

LayerNorm:  y = (x - mean(x)) / sqrt(var(x) + ε) * γ + β
RMSNorm:    y = x / sqrt(mean(x²) + ε) * γ

RMSNorm이 LLM에서 선호되는 이유:

  • mean 계산이 필요 없음 → 연산량 감소
  • bias(β) 없음 → 파라미터 수 감소
  • 실험적으로 LayerNorm과 성능이 비슷

수식 분해

1. 제곱합:    sum_sq = Σ(x_i²)
2. RMS:       rms = sqrt(sum_sq / n + ε)
3. 정규화:    x_norm = x / rms
4. 스케일링:  y = x_norm * γ

커널 동작 원리


코드 라인별 설명

PyTorch 참조 구현

커널 함수

래퍼 함수


02 Fused Softmax와의 차이점

  02 Softmax 03 RMSNorm
reduction max + sum (2번) sum (1번)
수치 안정성 max 빼기 eps 더하기
범위 밖 채움 -inf 0.0
추가 입력 없음 가중치 γ
입력 shape 2D만 3D/4D → 2D 변환

벤치마크 결과

PyTorch의 수동 RMSNorm 구현 대비 커널 퓨전으로 인한 성능 향상이 나타납니다. hidden_size가 클수록(2048, 4096 등) 차이가 명확합니다.


전체 코드




Enjoy Reading This Article?

Here are some more articles you might like to read next:

  • LLM 엔지니어가 알아야 할 GPU 아키텍처: Ampere → Hopper → Blackwell
  • FlashAttention-4: Algorithm and Kernel Pipelining Co-Design for Asymmetric Hardware Scaling
  • FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision
  • Triton 05: Flash Attention — 종합 프로젝트
  • Triton 04: Matrix Multiplication — 2D 타일링과 Autotune