Triton 04: Matrix Multiplication — 2D 타일링과 Autotune

개요

딥러닝의 핵심 연산인 행렬 곱셈(GEMM)을 Triton으로 구현합니다. 2D 타일링, tl.dot, triton.autotune 등 고급 기능을 학습합니다.


핵심 개념

행렬 곱셈이 왜 중요한가

딥러닝의 거의 모든 연산이 행렬 곱셈:

  • Linear layer: y = xW + b
  • Attention: QK^T, PV
  • MLP: 모든 Feed-Forward 블록

나이브 vs 타일링

나이브: 출력의 각 원소마다 Global Memory에서 행/열 전체를 읽음 → 같은 데이터를 반복 로드

타일링: 행렬을 작은 블록으로 나누어 SRAM에 올리고, 블록 단위로 계산

행렬 곱셈 타일링 전략 도식

커널 동작 원리

2D 그리드

이전 튜토리얼은 1D 그리드(행 단위)였지만, MatMul은 2D 그리드를 사용합니다:

행렬 곱셈용 2D 그리드 매핑

K 차원 루프

행렬 곱셈 C = A × B에서 A(M×K), B(K×N)일 때, K가 크면 한 번에 SRAM에 못 올립니다. 그래서 K를 BLOCK_SIZE_K씩 잘라서 반복하며, 부분 결과를 누적합니다.

K 차원 루프에서 포인터 이동 과정

L2 캐시 최적화 (Swizzling)

Swizzling = “같은 B 블록을 쓰는 프로그램들을 묶어서 실행”

프로그램 그룹 순서 개념도
Swizzling 상세 동작 다이어그램
L2 캐시 Swizzling 패턴

triton.autotune 이란?

블록 크기에 따라 성능이 크게 달라집니다. Autotune은 여러 설정을 실행해보고 가장 빠른 것을 선택합니다:


코드 라인별 설명

K 차원 루프 (핵심)

이전 튜토리얼과의 차이점

  01~03 04 MatMul
그리드 1D (행 수) 1D (M타일 × N타일)
데이터 1D 벡터/행 2D 블록 (타일)
루프 없음 K 차원 루프
핵심 연산 +, exp, sum tl.dot (텐서 코어)
파라미터 튜닝 수동 BLOCK_SIZE triton.autotune

벤치마크 결과

행렬 곱셈 성능 벤치마크 결과

cuBLAS(torch.matmul)는 수십 년간 최적화된 라이브러리입니다. Triton으로 cuBLAS의 80~90% 성능에 도달하는 것이 목표입니다.


전체 코드




Enjoy Reading This Article?

Here are some more articles you might like to read next:

  • Triton 07: Flash Attention 3 — Triton으로 어디까지 가능한가
  • Triton 06: Flash Attention 2 — FA1 대비 5가지 최적화
  • Triton 05: Flash Attention — 종합 프로젝트
  • Triton 03: RMSNorm — LLM에서 쓰이는 실전 커널
  • Triton 02: Fused Softmax — 커널 퓨전과 Reduction