TRL sequence packing → DeepSeek MLA: 누락된 cu_seqlens 복원
Introduction
이전 두 글에서 DeepSeek 계열 MoE 학습 가속을 다뤘다 — MoE grouped GEMM fusion (6.27×) 과 MLA projection fusion (A+B+D) (1.05×). 그 위에 적용할 다음 카드가 sequence packing. 평균 문장 길이가 max_length 보다 한참 짧은 SFT 데이터셋이면 padding 영역이 GEMM compute 의 90% 이상을 잡아먹기 때문에, 짧은 sample 여러 개를 한 slot 에 묶어 채우면 effective throughput 이 dataset density 만큼 폭증한다.
그런데 DeepSeek-V3 modeling 코드를 그대로 쓰면 packing 을 켜는 순간 loss 가 망가진다 — 우리 환경에서 step 1 loss 가 2.57 (no packing) → 5.70 (packing) 으로 폭주했다. 모델이 거의 random initialization 수준의 예측을 내놓는다.
이 글은 그 망가짐의 정확한 원인을 추적하고, modeling 파일을 손대지 않은 채 attention dispatcher 한 군데에서 position_ids 의 0-reset 위치로부터 cu_seqlens 를 복원해 해결한 과정을 정리한다. 결과는 다음과 같다.
| 단계 | per-step (8 GPU FSDP) | loss step 1 | peak GPU mem |
|---|---|---|---|
| baseline (no packing) | 16.65 s | 2.589 | 57.7 GB |
| packing on, 깨짐 | 4.45 s | 5.701 ← random | 33.0 GB |
| dispatcher fix 1 (kwargs) | 20.40 s | 5.700 ← 변화 없음 | 33.0 GB |
| dispatcher fix 2 (position_ids) | 3.58 s | 1.855 ← 학습 정상 | 25.1 GB |
마지막 줄에서 wall-time 4.65× 가속 + 메모리 25 GB / 학습 정합성 회복 이 동시에 달성됐다.
Background — TRL sequence packing 의 동작
TRL 1.2 의 SFTConfig(packing=True, packing_strategy="bfd") 를 켜면 다음이 자동으로 일어난다.
- 데이터 단: dataset 의 각 sample 길이를 측정하고 BFD (Best-Fit-Decreasing) 알고리즘으로
max_lengthslot 에 패킹. 한 slot 에 평균 50–80 개 짧은 doc 이 들어간다. -
padding_free=True자동 활성: padding 토큰을 아예 만들지 않는다. batch 의 shape 가(1, total_packed_length)로 고정. - 메타데이터 변경:
attention_mask가 batch 에서 사라지고 대신 doc 경계를 표현하는 다음 키들이 들어간다.-
position_ids— 각 doc 마다 0 으로 reset -
cu_seq_lens_q,cu_seq_lens_k— doc 경계 cumulative seq lens -
max_length_q,max_length_k— 가장 긴 doc 길이
-
packed batch 의 실제 모양을 보면 다음과 같다.
여기서 핵심은 두 가지다.
-
position_ids는 doc 경계 정보를 그대로 갖고 있다 — 0 으로 reset 되는 위치가 곧 doc 시작점. -
attention_mask가 batch 에 없다 — 따라서 modeling 의 FA2 forward 는 이 사실을 받아 처리해야 한다.
무엇이 깨지는가 — _flash_attention_forward 의 비-varlen 분기
DeepSeek-V3 의 modeling_deepseek.py 안 DeepseekV3FlashAttention2._flash_attention_forward 를 그대로 보면 다음과 같다.
분기 두 개가 핵심이다.
-
attention_mask is not None—_upad_input으로 cu_seqlens 를 만들어flash_attn_varlen_func호출. 단 cu_seqlens 는attention_mask.sum(dim=-1)로 derive 되므로 padding 영역만 잘라낼 뿐, 한 row 안의 doc 경계는 안 본다. 즉 packed 시 padding 이 없으므로 attention_mask 가 batch 에 아예 없거나 전부 1. -
attention_mask is None— 그냥flash_attn_func(non-varlen) 호출. 전체 (1, S, H, D) 텐서를 단일 causal sequence 로 처리.
TRL 의 padding_free packing 은 후자로 빠진다. 결과적으로 한 row 안에 packing 된 50–80 개의 서로 다른 doc 의 토큰들이 서로 attend 한다. causal mask 가 적용되긴 하지만 그건 row 내 absolute 위치 기준이라, doc 경계 따위는 모른다. 학습 signal 이 망가지는 게 당연하다.
loss 측정으로 다시 확인하면 packed step 1 의 loss 가 5.70, mean token accuracy 가 0.25 (≈ random) 이다. entropy 6.0 (uniform 에 가까움). 모델이 packed input 에 대해 “이건 학습된 적 없는 분포다” 라고 말하고 있다.
첫 번째 시도 (실패): kwargs 경로로 cu_seq_lens 받기
TRL data collator 가 batch 에 cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k 를 넣어 보낸다. 자연스러운 fix 는 attention forward 에서 **kwargs 로 받아 그대로 flash_attn_varlen_func 에 넘기는 것. 우리도 처음엔 이 방향으로 시도했다.
그러나 결과는 loss 가 bit-identical 로 동일하게 망가진 상태였다. 5.7010 vs 5.7003 — bf16 ULP 노이즈 한도 안. 즉 fix 가 발동조차 안 했다.
원인을 추적해 보니 다음과 같다.
DeepseekV3Model.forward (그리고 우리 모델의 동형 AXK1Model.forward) 의 시그니처에 **kwargs가 없다. TRL 이 batch 에 정성스럽게 넣어 보낸cu_seq_lens_q/k, max_length_q/k가 모델 진입부에서 silently drop 된다. 그래서 attention 까지 도달했을 때 우리 dispatcher 의kwargs.get("cu_seq_lens_q")는 항상None. fallback 으로 다시 비-varlen 경로로 빠지면서 학습 망가짐 그대로.
이걸 고치려면 modeling 코드 자체를 수정해야 하는데, 우리는 modeling_axk1.py / modeling_deepseek.py 를 손대지 않는다는 원칙을 지키고 싶었다. 즉 모델 forward signature 변경 없이 packing 정보를 attention 까지 전달할 다른 길이 필요했다.
두 번째 시도 (성공): position_ids 에서 cu_seqlens 복원
해결의 단서는 position_ids 는 살아 있다 는 점이다. DeepseekV3Model.forward 의 명시 인자라서, 사용자 코드를 손대지 않아도 attention forward 까지 정상 전파된다. 그리고 packed batch 에선 position_ids 가 doc 마다 0 으로 reset 된다. 즉 cu_seqlens 정보는 position_ids 안에 인코딩돼 있다.
이걸 derive 하는 헬퍼는 한 줄 수준이다.
이제 attention dispatcher 를 다음처럼 짠다.
세 갈래로 분기한다.
- kwargs 경로: 만약 modeling 이 후일 수정돼서
cu_seq_lens_q가 attention 까지 도달할 수 있으면, 그걸 그대로 쓴다. - position_ids 자동 derive: kwargs 가 비어 있어도
attention_mask is None+position_ids에 0 이 두 번 이상 나오면 packed 로 인식하고 헬퍼로 cu_seqlens 를 만든다. - padded fallback: 둘 다 아니면 원래
_flash_attention_forward위임 (기존 padded path 변경 없음).
flash_attn_varlen_func 가 cu_seqlens 를 받으면 doc 경계를 정확히 인식해서 cross-doc attention 을 0 으로 처리한다. softmax_scale, causal flag 등 다른 모든 설정은 동일하게 유지.
마지막으로 modeling 의 attention forward 를 위 dispatcher 를 거치는 fused 버전으로 monkey-patch 한다 — 같은 시리즈의 MLA projection fusion 글 의 fused_mla_forward_ab 안에 dispatcher 가 들어가는 형태로 두면 깔끔하다.
결과
A100 × 8 FSDP smoke (5 step, max_length=8192, KoAlpaca-RealQA, per_device_bs=1, grad_accum=16, fused_moe + fused_mla A+B + FA2 stack).
| metric | baseline | broken packing | fix 1 (kwargs) | fix 2 (position_ids) |
|---|---|---|---|---|
| per-step (s2–s5)/4 | 16.65 s | 4.45 s | 20.4 s | 3.58 s |
| train_runtime (5 step) | 105.4 s | 43.7 s | 109.1 s | 39.4 s |
| Peak GPU mem | 57.7 GB | 33.0 GB | 33.0 GB | 25.1 GB |
| Loss step 1 | 2.589 | 5.701 | 5.700 | 1.855 |
| Loss step 5 | 2.568 | 5.686 | 5.686 | 1.849 |
| entropy step 1 | 1.85 | 6.01 | 6.01 | 1.48 |
| mean_token_accuracy | 0.62 | 0.25 | 0.25 | 0.64 |
| grad_norm | 17.5 | 20.75 | 20.75 | 8.2 |
세 가지 관찰.
- Loss 가 1.85 로 baseline 2.57 보다 오히려 낮다. 이는 packing 으로 한 row 안에 실제 토큰 비율이 ~1% (no packing) → ~100% (packing) 로 바뀌면서, loss 평균이 padding 토큰이 아닌 의미 있는 토큰들 위에서 계산되기 때문. 실제 학습 진행은 step 마다 작지만 일관된 감소 (1.855 → 1.849) 로 확인된다.
- 메모리 절반 (57.7 GB → 25.1 GB). padding zero 영역의 activation 을 만들지 않아서.
- per-step wall-time 16.65 → 3.58 s = 4.65×. 그리고 broken packing (4.45 s) 보다도 약간 더 빠르다 — varlen path 가 non-varlen path 보다 효율적이라는 뜻 (zero-padding 영역을 진짜로 건너뛰므로).
이전 두 글의 가속과 합치면 누적 효과는 다음과 같다.
| stack | per-step | vs 원본 |
|---|---|---|
| 원본 (naive Python loop, no packing) | 110.1 s | 1.00× |
| + fused MoE (grouped GEMM) | 17.55 s | 6.27× |
| + fused MLA (A+B) | 16.65 s | 6.61× |
| + packing (with position_ids-derived cu_seqlens) | 3.58 s | 30.8× |
한계와 함께 다룰 것
- B=1 가정: padding_free packing 은 batch row 가 1 인 상황을 전제로 한다. multi-batch packed 모드를 쓰면 dispatcher 의
assert bsz == 1분기를 풀어야 한다. flash_attn_varlen_func 의 입력 layout 도 그에 맞춰 unpadded(total_tokens, H, D)로 재구성 필요. -
gradient_checkpointing=false와의 양립: packing 으로 real-token 비율이 ↑ 한 결과 activation 메모리도 비례해서 ↑. grad_ckpt 끄면 80 GB A100 1장에 들어가지 않는다 (smoke 검증 OOM). packing 환경에선 grad_ckpt 는 켜둬야 한다. - TRL 의
wrappedpacking strategy: BFD/BFD_split 이 아닌 wrapped 전략은 doc 경계를 다르게 표시할 수 있다. 우리는 BFD 만 검증했다. - modeling 코드를 직접 수정한다면:
DeepseekV3Model.forward에**kwargs를 추가해 TRL 이 보낸 cu_seq_lens 를 그대로 받으면 position_ids 의존도가 사라지고 dispatcher 가 더 단순해진다. 본 글의 fix 는 modeling 무수정을 전제로 한 우회로다.
Conclusion
SFTConfig(packing=True) 한 줄로 4–10× 가속이 나오는 영역인데, DeepSeek-V3 reference modeling 과 TRL padding_free 의 조합은 attention dispatcher 에 미세한 정보 결손이 있어 학습이 조용히 망가진다. cu_seqlens 가 모델 입구에서 drop 되더라도 position_ids 는 끝까지 살아남는다는 사실을 활용하면, modeling 파일 무수정 + 한 군데 dispatcher 패치만으로 학습 정합성을 회복하고 가속 효과를 그대로 가져갈 수 있다.
본인의 환경에서 packing 을 켜고 loss 가 평소보다 두 배 이상으로 폭주한다면, 가장 먼저 의심해야 하는 곳이 정확히 이 지점이다.
참고 문헌
Enjoy Reading This Article?
Here are some more articles you might like to read next: