DeepSeek 계열 MoE 학습 가속: Python expert loop → grouped GEMM

Introduction

DeepSeek-V2/V3 의 공개 modeling 코드를 그대로 SFT 학습에 쓰면, MoE 가 학습 step time 의 사실상 전부를 잡아먹는다. 원인은 modeling 의 forward 가 학습 경로를 사실상 제공하지 않고, 사용자가 직접 작성하게 되어 있는 moe_train 이 보통 128~256 expert 를 Python for loop 으로 도는 naive 구현이라는 점에 있다.

이 글의 결과는 한 줄로 다음과 같다. 단일 GEMM 으로 묶어서 풀자. Variable-M grouped matmul (이하 grouped GEMM) 로 expert MLP 의 gate/up/down 세 번을 각각 1 번씩 부르면, 단일 GPU 마이크로벤치에서 6.69×, A100 8 GPU FSDP smoke 에서 end-to-end 6.27× 의 가속이 나온다.

측정 naive Python loop grouped GEMM
단일 레이어 (N=4096, fwd+bwd) 87.9 ms 13.1 ms (6.69×)
8 GPU FSDP per-step (smoke) 110.1 s 17.55 s (6.27×)
Peak GPU memory 동일
Loss 차이 (5 step) ≤ 0.16% rel (bf16 노이즈 내)

Background — DeepSeek 계열 MoE 구조

먼저 DeepSeek-V3 의 modeling_deepseek.py 의 MoE 모듈을 그대로 보자.

이상한 부분이 두 가지 보인다.

  1. forwardif not self.training: 분기에서만 y 를 정의한다. 학습 모드에서 y 는 정의되지 않은 상태로 shared expert 와 합쳐지면서 NameError 가 난다. 사용자가 직접 moe_train 을 구현해서 monkey-patch 또는 코드 수정으로 끼워 넣어야 한다.
  2. 추론 경로 moe_infer@torch.no_grad() 데코레이터가 붙어 있어 학습용으로 재사용 불가.

학습용 forward 를 직접 짜야 하니, 가장 자연스러운 형태가 다음과 같은 Mixtral-style scatter-gather 다.

매 forward 마다 num_experts × 3 = 768 개 작은 GEMM kernel 을 직렬로 launch 한다. A100 의 CUDA kernel launch latency 는 ~5–10 μs, 그래서 768 launch × 7 μs ≈ 5.4 ms 가 순수 launch overhead. 거기다 expert 당 평균 token 수는 N × top_k / num_experts 인데, N=4096 / top_k=8 / E=256 이면 expert 당 평균 128 token — 이 정도 크기의 GEMM 은 A100 의 tensor core 를 채우지 못해 GPU 가 거의 idle 상태가 된다.


무엇이 필요한가 — Variable-M batched matmul

해결책의 형태는 명확하다. Expert 별로 token 을 정렬한 뒤, 한 번의 kernel call 로 모든 expert 의 GEMM 을 처리하면 launch overhead 도 사라지고 GPU utilization 도 올라간다. 이를 위해 필요한 연산은 다음과 같다.

연산자 모든 그룹 GEMM 크기 API
torch.matmul / F.linear 동일 단일 행렬
torch.bmm 모든 그룹 같은 M batched (3-D)
gg.gmm (grouped GEMM) 그룹마다 다른 M grouped
torch.ops.aten._grouped_mm (torch 2.7+) 그룹마다 다른 M native (cu126+ 필요)

torch.bmm 은 모든 그룹의 M 이 같아야 하므로 expert routing 처럼 그룹 크기가 들쭉날쭉하면 zero-padding 으로 낭비가 발생한다. 진짜 “ragged batch” GEMM 는 cutlass 의 GroupedGemm kernel 이 표준 구현이며, tgale96/grouped_gemm 이 그 cutlass kernel 의 PyTorch autograd 바인딩이다 — MegaBlocks 의 block-sparse path 와 같은 author 의 작업.

API 는 다음과 같이 단순하다.


적용 — fused MoE training forward

이제 위 naive moe_train 을 grouped GEMM 으로 다시 작성한다.

핵심은 다음 세 가지다.

  • (1)–(3) 토큰 정렬: topk_ids 를 1D 로 펼친 뒤 expert id 로 sort 하면, 같은 expert 로 가는 토큰들이 연속된 메모리 슬라이스로 모인다. torch.bincount 가 expert 별 토큰 수를 한 번에 계산해 준다.
  • (5) Weight stacking: nn.Linear.weight(out, in) 레이아웃이므로, 그대로 torch.stack 하면 (E, out, in) 가 된다. trans_b=True 를 쓰면 grouped GEMM 이 내부에서 transpose 를 처리하므로 .transpose().contiguous() 의 메모리 copy 를 피할 수 있다 (autograd 도 정상 추적).
  • (6) 세 번의 gg.gmm: gate → up → down 의 expert MLP 세 단계 모두 동일한 group_sizes 를 사용한다. 768 launch 가 3 launch 로 줄어든다.

forward 에서 학습 모드에 이 함수를 호출하도록 분기시키면 끝이다.

moe_train 은 위의 fused_moe_train 을 monkey-patch 형태로 클래스에 붙이면 modeling 파일 자체는 손대지 않을 수 있다.


FSDP 와의 통합

이 fused forward 는 FSDP FULL_SHARD 와 자연스럽게 호환된다. 단, 한 가지 조건이 있다.

  • fsdp_use_orig_params=True 가 켜져 있어야 한다 (accelerate 의 FSDP config 에서 기본값이 false 일 수 있음). 이 옵션이 켜지면 self.experts[i].gate_proj.weight 같은 원본 Parameter 객체가 그대로 노출되어, torch.stack 이 정상 동작한다.
  • FSDP 의 wrap unit 은 transformer block 단위 (TRANSFORMER_BASED_WRAP) 가 자연스럽다. block 단위로 all-gather 가 일어나면, block 내부의 fused MoE forward 시점에서 expert weight 가 이미 로컬에 모여 있다.

E=256, H=2048, moe_intermediate=512 기준으로 stack 결과의 transient 메모리는 3 × 256 × 2048 × 512 × 2B ≈ 1.5 GB. A100 80GB 에선 충분히 감당 가능하고, 사용 직후 해제된다. 만약 E 가 훨씬 크거나 메모리가 빠듯한 환경이면, 세 weight 를 한꺼번에 쌓지 말고 GEMM 직전에 하나씩 쌓는 식으로 peak 를 낮출 수 있다.


Experiments

마이크로벤치 (단일 GPU, A100)

E=128, H=2048, moe_intermediate=512, top_k=8, fwd+bwd 30 iter 평균.

Token 수 N naive Python loop grouped GEMM Speedup
N=4096 87.94 ms 13.14 ms 6.69×

원인 분해:

  • Kernel launch overhead 제거 (768 → 3)
  • 작은 expert GEMM (avg 256 행) 들이 큰 grouped GEMM 한 번으로 묶이면서 tensor core utilization ↑

End-to-end FSDP smoke (A100 × 8, FULL_SHARD)

bf16, gradient_checkpointing on, per_device_bs=1, grad_accum=16, max_length=8192 SFT smoke. 5 step.

Stack per-step (steady) train_runtime Cumulative speedup
baseline (naive Python loop) 110.1 s 572 s 1.00×
+ fused MoE (grouped GEMM) 17.55 s 107 s 6.27×

수치 동등성: 5 step loss 의 max relative diff 0.16% (bf16 ULP 노이즈 내), peak GPU memory 변화 없음.

흥미로운 점은 마이크로벤치 6.69× 가 end-to-end 6.27× 와 거의 일치한다는 사실이다. 이는 SFT 학습 step time 의 사실상 전부가 MoE forward + backward 였다는 것을 뜻한다. attention 이나 FSDP all-gather 가 의미 있게 보이려면 이 병목을 먼저 걷어내야 측정이 가능하다.


한계

  • ep_size > 1 미지원: Expert Parallel 환경에선 각 rank 가 자신이 소유한 expert subset 만 들고 있고, all-to-all 통신이 필요하다. 이 경우 grouped GEMM 만으로는 부족하고 DeepEP 같은 token dispatch + combine kernel 이 함께 필요하다. 본 글의 fused path 는 단일 노드 / 단일 rank 가 모든 expert 를 소유하는 FSDP 만 다룬다.
  • 추론은 그대로: moe_infer 의 sorted-token argsort + Linear loop 패턴은 추론 시점에선 KV cache, paged attention 등과 함께 묶여 vLLM/SGLang 의 dispatch 와 따로 다뤄야 한다.
  • DeepGEMM (FP8) 미적용: DeepSeek-V3 본가는 H100/H800 에서 fine-grained FP8 grouped GEMM 을 자체 구현해 추가 2× 를 얻는다. A100 은 FP8 unit 부재로 bf16 grouped GEMM 까지가 한계.

Conclusion

DeepSeek-V3 계열 MoE 모델은 공개 modeling 코드가 학습용 MoE forward 를 사실상 제공하지 않는다. 사용자가 직접 작성하는 순간 가장 흔한 형태가 naive Python for loop 인데, 이게 step time 의 거의 전부를 차지한다. cutlass grouped GEMM 한 번이면 단일 GPU 6.69×, end-to-end 6.27× 가속이 나온다. 이 정도면 다른 모든 최적화는 이걸 끝낸 뒤에 시작하는 게 맞다.


참고 문헌




Enjoy Reading This Article?

Here are some more articles you might like to read next:

  • TRL sequence packing → DeepSeek MLA: 누락된 cu_seqlens 복원
  • MLA 학습 시 modeling-side projection fusion: q_a/kv_a 배치 + K-side absorption
  • LLM 엔지니어가 알아야 할 GPU 아키텍처: Ampere → Hopper → Blackwell
  • Keras 3.0 설명
  • Jetson Nano Tensorrt 적용