학부생이 보는 GAN

논문 링크: Generative Adversarial Network

GAN은 2014년도에 나온 논문으로 현재 많은 연구에 영향을 끼치고 있고 Yann LeCun이 혁명적인 아이디어라고 극찬했다. GAN은 Image Generation에 관한 기초 모델로 이를 활용해 늙은 사진, 안경쓴 사진 등 원하는 이미지를 만들어낼 수 있다.

Contribution

이 논문에 Contribution은 다음과 같다.

  1. 이후 연구가 활발히 진행되는 GAN의 기본적인 이론적인 개념을 제시했다.
  2. ganerate된 이미지는 하나의 지점으로 수렴하며 이 지점은 하나뿐인 global optimum이라는 것을 증명했다.

Basic Concept

“Adversarial”이라는 단어는 적대적인 이라는 뜻을 갖는다. 논문 제목에서 알 수 있듯 이 논문에서 두 네트워크는 서로 적대적인 관계에 있으며 서로 경쟁하면서 학습해 나간다.

다음 두 네트워크 Generator, Discriminator가 있다. Generator는 이미지를 만들어내는 네트워크이고 Discriminator는 이미지들이 Generator에서 만들어진 이미지인지 실제 데이터셋에 있는 실제 이미지인지 구분한다. GAN 논문에서는 이것을 지폐위조범과 경찰로 묘사했다.

지폐위조범인 Generator는 들킬 위험이 없는 위조지폐를 만드는 것이 목표다. 그리고 경찰인 Discriminator는 이 위조지폐를 찾아내는 것을 목표로 하고 있다. 이러한 상황에서 각각의 네트워크들은 자신들의 성능들을 높일 것이고 결과적으로 위조지폐가 완벽해서 실제지폐와 구분할 수 없다. (p=0.5)

수학적으로 접근해보면 다음과 같다. Generator는 우리가 갖고있는 data들의 distribution을 모사한다. real data를 \(x\), Generator가 입력으로 z를 받아 뽑은 Sample data를 \(G(z)\)라 하겠다. (z는 보통 Gaussian noise이다.) 만약 Discriminator가 잘 학습이 되었다면 \(D(x)=1, D(G(z))=0\)이 될 것이고, Generator가 학습이 잘 된다면 \(D(G(z))=1\)이 될 것이다. Discriminator는 minimum으로 Generator는 maximum으로 각각 경쟁하며 학습해서 min-max problem이다.

Loss Function

위를 수식으로 정의하면 다음과 같다.

$$min_G max_D V(D,G) = E_{x~p_{data}}[logD(x)] + E_{x~p_z(z)}[log(1-D(G(z)))]$$

이해가 잘 안된다면 극단적으로 접근하면 된다. Discriminator가 학습이 잘 되었다면 \(D(x)=1, D(G(z))=0\)가 될 것이고, 결과적으로 \(V(D,G)=0\)으로 maximum이 될 것이다. 반대로 Generator가 학습이 잘 되었다면 \(D(G(z))=1\)이 될 것이고 \(V(D,G)=-\infty\)로 minimum이 될 것이다.

GAN 논문에서 제시하고 있는 Distribution이다. 검은색 점선은 real data distribution, 초록색 점선은 Generator distribution, 보라색 점선은 Discriminator distribution이다. 초기상태 (a)에서는 비교적 Discriminator가 real data와 sample data를 잘 판별했으나 학습이 될수록 real data와 sample data의 distribution이 비슷해져 Discriminator가 각각의 입력을 받았을 때, 출력하는 예측값은 0.5가 된다.

Global Optimality \(p_g=p_{data}\)

Proposition 1. generator G가 고정되었을때 최적의 dicriminator D는

$$D^*_G(x)=\frac {p_{data}(x)}{p_{data} + p_g(x)}$$

Proof.

$$min_G max_D V(D,G) = E_{x~p_{data}}[logD(x)] + E_{x~p_z(z)}[log(1-D(G(z)))]$$ $$V(G,D)=\int_x p_{data}(x)log(D(x))dx + \int_zp_z(z)log(1-D(G(z)))dz$$ $$V(G,D)=\int_x p_{data}(x)log(D(x)) + p_z(z)log(1-D(G(z)))dz$$

어떤 \((a, b) \in R^2\setminus\{0,0\}\)에서, 함수 \(y \rightarrow alog(y) + blog(y)\)는 [0, 1]범위에서 최댓값 \(\frac{a}{a+b}\)을 갖는다.

위의 식을 다음과 같이 변형할 수 있다.

$$C(G)= max_D(G,D)$$ $$ = E_{x~p_{data}}[logD^*_G(x)] + E_{x~p_z(z)}[log(1-D^*_G(G(z)))]$$ $$ = E_{x~p_{data}}[logD^*_G(x)] + E_{x~p_z(z)}[log(1-D^*_G(x))]$$ $$ = E_{x~p_{data}}[log\frac {p_{data}(x)}{p_{data} + p_g(x)}] + E_{x~p_z(z)}[log\frac {p_{g}(x)}{p_{data} + p_g(x)}]$$

Theorem 1. \(C(G)\)의 global minimum은 오직 \(p_g=p_{data}\)뿐이고, 이때 \(C(G)=-log4\)이다.

직관적으로 생각했을 때 \(p_g=p_{data}\)이면 \(D^*_G(G)=\frac {1}{2}\)이다.

$$C(G)=E_{x~p_{data}}[-log2] + E_{x~p_z(z)}[-log2] = -log4$$

이를 다음과 같이 생각할 수 있다.

$$E_{x~data}[log\frac {p_{data}(x)}{p_{data} + p_g(x)}] + E_{x~p_g}[log\frac {p_{g}(x)}{p_{data} + p_g(x)}]$$ $$C(G)=-log(4) + KL(p_{data}||\frac{p_{data} + p_g}{2}) + KL(p_{g}||\frac{p_{data} + p_g}{2})$$ $$C(G)=-log(4) + 2*JSD(p_{data}||p_{g})$$

Jensen-Shannon divergence의 범위는 \([0, \infty]\)이고 그 최소점은 \(p_{g}=p_{data}\)이다. 따라서 C(G)의 최소값은 \(-log(4)\)이다.

Convergence of Algorithm

Proposition 2. 만약 G과 D가 gradient descent 알고리즘으로 충분히 학습된다면 D는 다음 식에서 주어진 G과 \(p_g\)에대해 optimum에 도달하게 된다.

$$ = E_{x~p_{data}}[logD^*_G(x)] + E_{x~p_z(z)}[log(1-D^*_G(G(z)))]$$

Proof

if \(f(p_g)=sup_{D\in}f_D(p_g)\) and \(f_D(p_g)\) is convex in \(p_g\) every \(D\), then \(\vartheta f_{D^*}(p_g) \in \vartheta f\) if \(D^*=argsup_{D\in D}f_D(p_g)\)

여기서 \(f_D(p_g)\)는 앞에서 살펴본 \(C(G)\)와 같다. \(C(G)\)는 JS divergence으로 convex함수이다. 이때 모든 D에서 이 식은 성립하므로 D의 optimal인 f_{D^*}(p_g)도 convex함수이다. 따라서 우리가 풀고자하는 문제가 convex함수이기 때문에 gradient descent 알고리즘을 사용하면 global optimum에 도달한다.

Limitation

앞서 살펴본 내용들을 생각한다면 혁신적인 아이디어는 맞다. 하지만 모든 초기연구가 그렇듯 한계가 있다.

Unstable

사실 Loss함수 입장에서보면 minimum이든 maximum이든 어느쪽으로가든 상관이 없다. 즉

$$min_G max_D V(D,G) = E_{x~p_{data}}[logD(x)] + E_{x~p_z(z)}[log(1-D(G(z)))]$$

여기서 Generator를 잘 학습시키는 것 대신 Discriminator를 잘 속이는 것으로 학습방향이 흘러갈수있다. 예를들어 mnist dataset에서 Generator는 Discriminator를 잘 속이기 위해 숫자 6만 만들어낸다고 하자. 그러면 Discriminator는 숫자 6이 나오면 Generator에서 나오는 것으로 판단하고 6이라는 이미지는 fake image라고 판단한다. 이후 Generator는 Discriminator의 판단을 속이기 위해 8을 만들어낼 것이고, 앞선 상황이 반복될 것이다.




Enjoy Reading This Article?

Here are some more articles you might like to read next:

  • LLM 엔지니어가 알아야 할 GPU 아키텍처: Ampere → Hopper → Blackwell
  • FlashAttention-4: Algorithm and Kernel Pipelining Co-Design for Asymmetric Hardware Scaling
  • FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision
  • Triton 05: Flash Attention — 종합 프로젝트
  • Triton 04: Matrix Multiplication — 2D 타일링과 Autotune