Data Selection for Language Models via Importance Resampling

Data Selection for Language Models via Importance Resampling

Introduction

언어 모델의 성능은 학습 데이터의 품질에 크게 좌우된다. GPT-3는 Common Crawl에서 품질 필터링을 거친 데이터를 사용했고, PaLM은 소셜 미디어 대화와 웹 문서를 섞어서 썼다. 하지만 이런 데이터 선택 과정은 대부분 휴리스틱에 의존하거나, 전문가가 수작업으로 큐레이션하는 방식이었다.

이 논문은 데이터 선택 문제를 분포 매칭(distribution matching) 관점에서 형식화한다. 핵심 질문은 이렇다: “대규모 raw 데이터셋에서, 원하는 타겟 분포와 가장 비슷한 서브셋을 어떻게 효율적으로 선택할 수 있을까?”

저자는 DSIR (Data Selection with Importance Resampling)을 제안한다. 통계학의 importance resampling 기법을 텍스트 데이터에 적용하되, hashed n-gram feature 공간에서 분포를 추정하여 계산 비용을 극적으로 낮춘다. 결과적으로 The Pile의 1억 문서를 4.5시간 만에 단일 CPU 노드에서 처리할 수 있다.

Background

Importance Resampling이란

Importance resampling은 한 분포(raw 분포 \(q\))에서 뽑은 샘플을, 다른 분포(타겟 분포 \(p\))에서 뽑은 것처럼 변환하는 기법이다. 핵심은 importance weight를 계산하는 것이다:

\[w(x) = \frac{p(x)}{q(x)}\]

타겟 분포 \(p\)에서 나올 확률이 높은 샘플일수록 weight가 크고, raw 분포 \(q\)에서만 자주 나오는 샘플은 weight가 작다. 이 weight에 비례하여 리샘플링하면, 결과적으로 타겟 분포에 가까운 서브셋을 얻을 수 있다.

문제: 고차원 텍스트에서의 밀도 추정

하지만 텍스트는 고차원 데이터이다. Raw 텍스트 공간에서 \(p(x)\)와 \(q(x)\)를 직접 추정하는 것은 통계적으로 불가능하다(intractable). 그래서 DSIR은 저차원 feature 공간에서 분포를 추정하는 접근을 취한다.

Method

Step 1: Hashed N-gram Feature Extraction

DSIR은 텍스트를 hashed n-gram feature로 변환한다. 구체적으로:

  1. 각 텍스트 \(x\)에서 unigram과 bigram 목록을 생성한다
  2. 각 n-gram을 \(m\)개의 버킷 중 하나로 해싱한다 (논문에서는 \(m = 10{,}000\))
  3. 각 버킷의 등장 횟수를 세어 \(m\)차원 벡터 \(z \in \mathbb{N}^m\)을 만든다

예를 들어 “Alice is eating”이라는 텍스트는:

  • Unigrams: [Alice, is, eating]
  • Bigrams: [Alice is, is eating]
  • 각 n-gram을 해싱 → 해시 인덱스 [1, 3, 3, 2, 0]
  • 카운트 벡터 → \(z = [1, 1, 1, 2, 0, \ldots]\)

이렇게 하면 텍스트의 표면적 어휘 분포를 포착하면서도, \(m = 10{,}000\)차원이라는 다루기 쉬운 공간으로 축소할 수 있다.

Step 2: Bag-of-Ngrams 생성 모델로 분포 추정

Feature 공간에서 raw 분포 \(\hat{q}_\text{feat}\)과 타겟 분포 \(\hat{p}_\text{feat}\)를 추정해야 한다. DSIR은 bag-of-ngrams 생성 모델을 사용한다:

\[\mathbb{P}(z; \gamma) = \prod_{j=1}^{m} \gamma[j]^{z[j]}\]

여기서 \(\gamma[j]\)는 \(j\)번째 해시 버킷의 발생 확률이다. 이 파라미터는 단순히 빈도를 세는 것만으로 추정할 수 있다:

\[\hat{\gamma} = \frac{1}{\sum_i \mathbf{1}^\top \tilde{z}_i} \sum_j \tilde{z}_j\]

즉, 전체 데이터에서 각 해시 버킷이 나타나는 비율을 계산하면 된다. 이 과정을 raw 데이터와 타겟 데이터 각각에 대해 수행하면 \(\hat{q}_\text{feat}\)와 \(\hat{p}_\text{feat}\)를 얻는다.

Step 3: Importance Resampling

각 raw 데이터 포인트 \(z_i\)에 대해 importance weight를 계산한다:

\[w_i = \frac{\hat{p}_\text{feat}(z_i)}{\hat{q}_\text{feat}(z_i)}\]

그리고 이 weight에 비례하는 확률로 \(k\)개의 샘플을 비복원 추출(without replacement)한다. 비복원 추출은 Gumbel top-k trick을 사용하여 효율적으로 수행한다.

이 세 단계를 정리하면 다음과 같다:

단계 내용 계산 비용
1. Feature 추출 hashed n-gram 벡터 계산 전체 시간의 ~97%
2. 분포 추정 빈도 카운팅으로 \(\hat{p}, \hat{q}\) 학습 무시할 수준
3. 리샘플링 weight 기반 비복원 추출 6분 이내

DSIR 파이프라인의 핵심은 단순함이다. 뉴럴 네트워크 학습도 없고, GPU도 필요 없다. 해시 버킷 빈도만 세면 된다.

KL Reduction: 데이터 선택 품질 측정 지표

저자는 데이터 선택 방법의 품질을 측정하는 지표로 KL reduction을 제안한다:

\[\text{KL-reduction}(p'_\text{feat}; \hat{q}_\text{feat}, \mathcal{T}) = \frac{1}{|\mathcal{T}|} \sum_{\hat{p}_\text{feat} \in \mathcal{T}} \left[ \text{KL}(\hat{p}_\text{feat} \| \hat{q}_\text{feat}) - \text{KL}(\hat{p}_\text{feat} \| p'_\text{feat}) \right]\]

직관적으로 해석하면, “선택된 데이터의 분포 \(p'_\text{feat}\)이 랜덤 선택(\(\hat{q}_\text{feat}\)) 대비 타겟 분포에 얼마나 더 가까워졌는가”를 측정하는 것이다. KL reduction이 클수록 타겟 분포에 더 잘 매칭된 것이다.

8개의 데이터 선택 방법에 걸쳐 KL reduction과 downstream 정확도 사이의 Pearson 상관계수가 r = 0.82로 매우 높다. 이는 hashed n-gram feature 공간에서의 KL reduction이 실제 downstream 성능을 잘 예측한다는 것을 의미한다.

Experiments

Domain-Specific Continued Pretraining

8개의 도메인별 타겟 데이터셋(CS, Biomedical, News, Reviews)에 대해 The Pile에서 데이터를 선택하고, RoBERTa를 continued pretraining한 뒤 downstream 태스크를 평가했다.

Method Avg F1
RoBERTa (continued pretrain 없음) 82.20
Random selection 82.58
Heuristic classification 82.88
Top-k Heuristic 83.65
DAPT (전문가 큐레이션) 83.46
DSIR 83.76
  • DSIR은 전문가 큐레이션(DAPT)보다 0.3% 높은 성능을 달성했다
  • 랜덤 선택 대비 1.18% 향상
  • 8개 도메인 중 대부분에서 top-k보다 같거나 높은 성능

위 히트맵은 각 타겟 분포로 선택한 pretraining 데이터가 서로 다른 fine-tune 태스크에 얼마나 잘 전이되는지 보여준다. 같은 도메인 내에서의 F1 (82.9%)이 교차 도메인 F1 (81.2%)보다 1.7% 높아, 도메인 매칭이 실제로 중요하다는 것을 확인할 수 있다.

DSIR이 선택하는 데이터의 분포

DSIR이 어떤 데이터를 선택하는지 살펴보면 직관적으로 이해할 수 있다.

왼쪽은 랜덤 선택, 오른쪽은 DSIR 선택의 Pile 도메인별 비율이다. 타겟이 Wikipedia+Books일 때, DSIR은 Pile-CC (49%)와 OpenWebText2 (30%)에 집중하는 반면, 랜덤은 모든 도메인에서 고르게 뽑는다. DSIR이 타겟과 유사한 텍스트 분포를 가진 도메인을 자동으로 식별한다는 것을 보여준다.

도메인별 타겟에서는 더 극적이다:

ChemProt(생화학) 타겟의 경우, DSIR은 PubMed Abstracts (48%)와 PubMed Central (28%)을 집중 선택한다. 생물의학 논문 도메인이 화학-단백질 관계 추출 태스크에 가장 유용한 pretraining 데이터라는 직관과 정확히 일치한다.

General-Domain Pretraining

타겟을 Wikipedia + Books로 설정하여 범용 언어 모델을 학습한 결과:

Method GLUE Avg
Random selection 80.25
Heuristic classification 79.85
Top-k Heuristic 81.47
DSIR 82.30
Top-k DSIR 81.38

DSIR은 랜덤 선택 대비 GLUE에서 +2.05% 향상을 달성했다. 흥미로운 점은 top-k 선택보다 importance resampling이 더 좋다는 것이다. Top-k는 weight가 가장 높은 \(k\)개만 선택하므로 다양성이 부족해지지만, importance resampling은 weight에 비례하여 확률적으로 샘플링하므로 관련성(relevance)과 다양성(diversity) 사이의 균형을 자동으로 맞춘다.

확장성 (Scalability)

단계 소요 시간
Importance weight 계산 4.36시간
Resampling (1억 문서) < 6분
전체 ~4.5시간

96코어 CPU 노드 1대, 96GB RAM으로 The Pile 전체(1억 문서)를 처리한다. 특히 리샘플링 단계는 선택할 문서 수와 무관하게 6분 이내로 완료되므로, 한 번 weight를 계산하면 다양한 크기의 서브셋을 즉시 만들 수 있다.

Conclusion

DSIR은 데이터 선택 문제를 importance resampling이라는 통계적 프레임워크로 형식화하고, hashed n-gram이라는 극도로 단순한 feature 공간에서 이를 효율적으로 구현했다. 전문가 큐레이션과 비슷하거나 더 좋은 성능을 GPU 없이 수 시간 만에 달성한다는 점에서 실용적 가치가 크다.

다만 hashed n-gram은 표면적 어휘 분포만 포착하므로, 깊은 의미나 구조적 품질을 반영하지 못한다는 한계가 있다. 이후 연구인 QuRating은 LLM 기반 품질 판단을, DsDm은 학습 알고리즘 인식(model-aware) 접근을 통해 이 한계를 각각 다른 방향에서 극복하려 시도한다.

참고 문헌




Enjoy Reading This Article?

Here are some more articles you might like to read next:

  • K8s 시리즈 06: EKS 네트워킹·보안·비용·운영
  • K8s 시리즈 05: Amazon EKS — 아키텍처와 Worker Node
  • K8s 시리즈 04: ConfigMap, Secret, Storage — 설정과 데이터 관리
  • K8s 시리즈 03: Service, Ingress — 트래픽 라우팅과 외부 접근
  • K8s 시리즈 02: Pod, Deployment, Job, CronJob — K8s 워크로드 총정리