GAN — Wasserstein GAN 및 WGAN-GP

Christian Kaindl의 사진

GAN 훈련은 어렵습니다. 모델은 수렴되지 않으며 모드 붕괴가 일반적입니다. 앞으로 나아 가기 위해 점진적으로 개선하거나 새로운 비용 기능을위한 새로운 경로를 채택 할 수 있습니다. GAN 교육에서 비용 함수가 중요합니까? 이 기사는 Wasserstein GAN (WGAN)과 WGAN-Gradient 패널티를 자세히 살펴 보는 GAN 시리즈의 일부입니다 . Wasserstein GAN의 방정식은 매우 접근하기 어렵습니다. 실제로는 매우 간단하며 예제로 설명합니다.

EM (Earth-Mover) 거리 / Wasserstein 미터법

상자 이동에 대한 간단한 연습을 완료하겠습니다. 6 개의 상자를 가져 와서 왼쪽에서 오른쪽의 점선으로 표시된 위치로 이동하려고합니다. 상자 # 1의 경우 위치 1에서 위치 7로 이동합니다. 이동 비용은 무게에 거리를 곱한 것과 같습니다. 단순화를 위해 가중치를 1로 설정합니다. 따라서 상자 # 1을 이동하는 비용은 6 (7–1)과 같습니다.

아래 그림은 두 가지 다른 이동 계획 γ를 보여줍니다. 오른쪽의 표는 상자가 이동하는 방법을 보여줍니다. 예를 들어, 첫 번째 계획에서 2 개의 상자를 위치 1에서 위치 10으로 이동하면 γ (1, 10) 항목이 2로 설정됩니다. 아래 두 플랜의 총 운송 비용은 42입니다.

그러나 모든 운송 계획이 동일한 비용을 부담하는 것은 아닙니다. 와서 스타 거리 (또는 EM 거리 ) 최저가 수송 계획의 비용이다. 아래 예에서 두 계획 모두 비용이 다르고 Wasserstein 거리 (최소 비용)는 2입니다.

설명하기 전에 복잡한 용어를 던져 보겠습니다. Wasserstein 거리는 데이터 분포 q 를 데이터 분포 p 로 변환 할 때 질량을 전송하는 최소 비용입니다 . 실제 데이터 분포 Pr 및 생성 된 데이터 분포 Pg에 대한 Wasserstein 거리는 수학적으로 모든 운송 계획 (즉, 가장 저렴한 계획의 비용)에 대한 최대 하한 (infimum)으로 정의됩니다.

WGAN 논문에서 :

Π (Pr, Pg)는 한계 값이 각각 Pr 및 Pg 인 모든 결합 분포 γ (x, y)의 집합을 나타냅니다.

수학 공식에 겁 먹지 마세요. 위의 방정식은 연속 공간에서의 예제와 동일합니다. Π 가능한 모든 운송 계획 γ를 포함 합니다.

우리는 변수 xy 를 결합하여 결합 분포를 형성합니다. γ (x, y)γ (1, 10) 은 단순히 위치 10에있는 상자의 수가 위치 1에서 가져온 것입니다. 위치 10에있는 상자의 수는 원래 임의 위치에서 가져와야합니다. , 즉 ∑ γ (*, 10) = 2. 이것은 γ (x, y) 가 각각 한계 PrPg를 가져야한다는 말과 같습니다 .

KL-Divergence 및 JS-Divergence

새로운 비용 함수를 옹호하기 전에 먼저 생성 모델에서 사용되는 두 가지 일반적인 차이, 즉 KL-Divergence와 JS-Divergence를 살펴 보겠습니다.

여기서 p 는 실제 데이터 분포이고 q 는 모델에서 추정 된 분포 입니다. 가우스 분포라고 가정 해 봅시다. 아래 다이어그램에서 우리 는 다른 평균을 갖는 p 와 몇 개의 q 를 플로팅 합니다.

아래에서 pq 사이의 해당 KL-divergence와 JS-divergence 를 0에서 35 사이의 평균 범위로 표시합니다. 예상대로 pq 가 모두 같을 때 발산은 0입니다. q 의 평균이 증가하면 발산이 증가합니다. 발산의 기울기는 결국 감소합니다. 기울기가 0에 가깝습니다. 즉, 생성기는 기울기 하강에서 아무것도 배우지 않습니다.

비판은 쉽습니다 . 실제로 GAN은 생성기보다 쉽게 ​​판별자를 최적화 할 수 있습니다. 최적의 판별기로 GAN 목적 함수를 최소화하는 것은 JS-divergence ( proof ) 를 최소화하는 것과 같습니다 . 위에서 설명한 것처럼 생성 된 이미지의 분포 q 가 Ground Truth p 에서 멀리 떨어진 경우 생성기는 거의 아무것도 학습하지 않습니다.

Arjovsky et al 2017은 GAN 문제를 수학적으로 설명하는 논문을 작성하여 다음과 같은 결론을 내 렸습니다.

  • 최적의 판별자는 생성기가 개선 할 수있는 좋은 정보를 생성합니다. 그러나 생성기가 아직 제대로 작동하지 않으면 생성기의 기울기가 줄어들고 생성기는 아무것도 배우지 않습니다 (우리가 방금 설명한 것과 동일한 결론).
원래의 GAN 논문은 이러한 기울기 소실 문제를 해결하기 위해 대체 비용 함수를 제안합니다. 그러나 Arjovsky는 새로운 기능이 모델을 불안정하게 만드는 큰 변화도를 가지고 있음을 보여줍니다.
Arjovsky는 모델을 안정화하기 위해 생성 된 이미지에 노이즈를 추가 할 것을 제안합니다.

Wasserstein 거리

노이즈를 추가하는 대신 Wasserstein GAN (WGAN)은 모든 곳에서 더 부드러운 기울기를 갖는 Wasserstein 거리를 사용하는 새로운 비용 함수를 제안합니다. WGAN은 발전기가 작동하는지 여부에 관계없이 학습합니다. 아래 다이어그램 은 GAN 및 WGAN 모두에 대한 D (X) 값에 대한 유사한 플롯을 반복합니다 . GAN (빨간색 선)의 경우 그라디언트가 감소하거나 폭발하는 영역으로 채워집니다. WGAN (파란색 선)의 경우 그라디언트가 모든 곳에서 더 매끄럽고 생성기가 좋은 이미지를 생성하지 않는 경우에도 더 잘 학습합니다.

출처

Wasserstein GAN

그러나 Wasserstein 거리에 대한 방정식은 매우 다루기 어렵습니다. Kantorovich-Rubinstein 이중성을 사용하여 계산을 단순화하여

여기서 sup 는 최소 상한이고 f 는이 제약을 따르는 1-Lipschitz 함수입니다 ( Lipschitz 제약에 대한 자세한 내용은 여기 를 참조 하십시오 ).

따라서 Wasserstein 거리를 계산하려면 1-Lipschitz 함수를 찾아야합니다. 다른 딥 러닝 문제와 마찬가지로 우리는이를 학습하기 위해 딥 네트워크를 구축 할 수 있습니다. 실제로,이 네트워크는 판별 매우 유사하다 , D 단지 시그 모이 드 함수 않고 확률보다는 스칼라 스코어를 출력한다. 이 점수는 입력 이미지가 얼마나 실제적인지로 해석 될 수 있습니다. 강화 학습에서는 상태 (입력)가 얼마나 좋은지 측정하는 가치 함수 라고합니다 . 새로운 역할을 반영하기 위해 차별자를 비평가 로 이름을 바꿉니다 . GAN과 WGAN을 나란히 보여 드리겠습니다.

GAN :

WGAN

네트워크 설계는 비평가가 출력 시그 모이 드 함수를 가지고 있지 않다는 점을 제외하면 거의 동일합니다. 주요 차이점은 비용 함수에만 있습니다.

그러나 한 가지 중요한 것이 누락되었습니다. f 는 1-Lipschitz 함수 여야합니다. 제약을 적용하기 위해 WGAN은 f 의 최대 가중치 값을 제한하기 위해 매우 간단한 클리핑을 적용합니다. 즉, 식별기 의 가중치는 하이퍼 파라미터 c에 의해 제어되는 특정 범위 내에 있어야합니다 .

연산

이제 우리는 아래의 의사 코드에 모든 것을 합칠 수 있습니다.

출처

실험

손실 메트릭과 이미지 품질 간의 상관 관계

GAN에서 손실은 이미지 품질의 척도보다는 판별자를 얼마나 잘 속이는 지 측정합니다. 아래와 같이 GAN의 발전기 손실은 이미지 품질이 향상 되어도 떨어지지 않습니다. 따라서 우리는 그 가치에서 진전을 말할 수 없습니다. 테스트 이미지를 저장하고 시각적으로 평가해야합니다. 반대로 WGAN 손실 함수는 더 바람직한 이미지 품질을 반영합니다.

출처

훈련 안정성 향상

WGAN에 대한 두 가지 중요한 기여는 다음과 같습니다.

  • 실험에서 모드 붕괴의 징후가 없으며
  • 생성기는 비평가가 잘 수행 할 때 여전히 학습 할 수 있습니다.
출처

WGAN — 문제

Lipschitz 제약

클리핑을 통해 비평가의 모델에 Lipschitz 제약 조건을 적용하여 Wasserstein 거리를 계산할 수 있습니다.

연구 논문의 인용문 : 가중치 클리핑은 Lipschitz 제약 조건을 적용하는 분명히 끔찍한 방법입니다. 클리핑 매개 변수가 크면 가중치가 한계에 도달하는 데 오랜 시간이 걸리므로 비평가를 최적 상태로 훈련하기가 더 어려워집니다. 클리핑이 작 으면 레이어 수가 많거나 배치 정규화를 사용하지 않을 때 (예 : RNN) 쉽게 그라디언트가 사라지게 될 수 있습니다. 그리고 단순성과 이미 좋은 성능으로 인해 가중치 클리핑을 고수했습니다.

WGAN의 어려움은 Lipschitz 제약 조건을 적용하는 것입니다. 클리핑은 간단하지만 몇 가지 문제가 있습니다. 모델은 여전히 ​​저품질 이미지를 생성 할 수 있으며 특히 하이퍼 파라미터 c 가 올바르게 조정되지 않은 경우 수렴되지 않습니다 .

모델 성능은이 하이퍼 파라미터에 매우 민감합니다. 아래 다이어그램에서 배치 정규화가 꺼져있을 때 c 가 0.01에서 0.1로 증가 하면 판별 기가 감소하는 그라데이션에서 폭발하는 그라데이션으로 이동합니다 .

출처

모델 용량

체중 클리핑은 체중 조절로 작동합니다. 모델 f 의 용량을 줄이고 복잡한 기능을 모델링하는 기능을 제한합니다. 아래 실험에서 첫 번째 행은 WGAN에서 추정 한 값 함수의 등고선 플롯입니다. 두 번째 행은 WGAN-GP라는 WGAN의 변형에 의해 추정됩니다. WGAN의 감소 된 용량은 개선 된 WGAN-GP가 할 수있는 반면 모델의 모드 (주황색 점)를 둘러싸는 복잡한 경계를 생성하지 못합니다.

출처

기울기 패널티가있는 Wasserstein GAN (WGAN-GP)

WGAN-GP는 가중치 클리핑 대신 기울기 패널티를 사용하여 Lipschitz 제약 조건을 적용합니다.

기울기 패널티

미분 할 수있는 함수 f 는 모든 곳에서 노름이 최대 1 인 기울기가있는 경우에만 1-Lipschitz입니다.

구체적으로, WGAN-GP 논문의 부록 A는

출처

실제 데이터와 생성 된 데이터 사이에 보간 된 점은 f에 대해 1의 기울기 노름을 가져야합니다 .

따라서 클리핑을 적용하는 대신 WGAN-GP는 그라디언트 노름이 목표 노름 값 1에서 멀어지면 모델에 페널티를줍니다.

λ는 10으로 설정됩니다. 그래디언트 노름을 계산하는 데 사용되는 점 xPgPr 사이에서 샘플링 된 모든 점 입니다. (나중에 의사 코드로 이것을 이해하는 것이 더 쉬울 것입니다.)

출처

비평가 (구별 자) 는 배치 정규화를 합니다. 배치 정규화는 동일한 배치의 샘플간에 상관 관계를 생성합니다. 이는 실험에 의해 확인 된 기울기 패널티의 효과에 영향을줍니다.

설계에 따라 일부 새로운 비용 함수는 비용 함수에 경사 패널티를 추가합니다. 일부는 순전히 기울기가 증가 할 때 모델이 오작동한다는 경험적 관찰을 기반으로합니다. 그러나 그래디언트 패널티는 바람직하지 않을 수있는 계산 복잡성을 추가하지만 더 높은 품질의 이미지를 생성합니다.

연산

샘플 포인트가 생성되는 방법과 그래디언트 패널티가 계산되는 방법을 자세히 설명하는 의사 코드를 살펴 보겠습니다.

출처

WGAN-GP 실험

WGAN-GP는 훈련 안정성을 향상시킵니다. 아래와 같이 모델 설계가 덜 최적화 된 경우 WGAN-GP는 원래 GAN 비용 함수가 실패하는 동안 여전히 좋은 결과를 생성 할 수 있습니다.

출처

다음은 다양한 방법을 사용한 시작 점수입니다. WGAN-GP 논문의 실험은 WGAN에 비해 더 나은 이미지 품질과 수렴성을 보여줍니다. 그러나 DCGAN은 약간 더 나은 이미지 품질을 보여주고 더 빠르게 수렴합니다. 그러나 WGAN-GP의 시작 점수는 수렴을 시작할 때 더 안정적입니다.

출처

그렇다면 DCGAN을 이길 수 없다면 WGAN-GP의 이점은 무엇입니까? WGAN-GP의 가장 큰 장점은 수렴성입니다. 훈련을 더 안정적으로 만들어 훈련하기가 더 쉽습니다. WGAN-GP는 모델이 더 잘 수렴하는 데 도움이되므로 생성기와 판별 자에 대해 깊은 ResNet과 같은 더 복잡한 모델을 사용할 수 있습니다. 다음은 WGAN-GP와 함께 ResNet을 사용한 시작 점수 (높을수록 좋음)입니다.

출처

Google Brain의 독립적 인 연구에서 WGAN 및 WGAN-GP는 최고의 FID 점수 중 일부를 달성했습니다 (낮을수록 좋음).

출처

추가 읽기

GAN을 더 잘 이해하려는 사람들을 위해 :

GAN — GAN의 갱스터에 대한 포괄적 인 검토 (1 부)

GAN 시리즈의 모든 기사 :

GAN — GAN 시리즈 (처음부터 끝까지)

참고

Wasserstein GAN

WGAN-GP : Wasserstein GAN 훈련 개선

Generative Adversarial Networks 훈련을위한 원칙적 방법을 향하여

Suggested posts

ML 대회에서 우승하는 방법

ML 대회에서 우승하는 방법

2021 년 머신 러닝 대회에서 우승하고 싶으신가요? 알아야 할 사항이 있습니다. 저는 2020 년 Kaggle, DrivenData, AICrowd, Zindi 및 기타 20 개 플랫폼에서 개최 된 100 개 이상의 대회 데이터베이스를 사용하여 ML 대회와 협력했습니다.

베이지안 확산 모델링을 사용한 고급 예측

베이지안 확산 모델링을 사용한 고급 예측

데이터 과학의 모든 영역에서 동적 현상을 예측하고 설명하기위한 혁신적인 모델링 솔루션에 대한 수요가 많습니다. 모델링 및 동적 현상 예측의 높은 프로필 사용 사례는 다음과 같습니다. 오픈 소스 데이터 세트에 적용된 베이지안 확산 모델링을 보여주는 종단 간 예제가 제공됩니다.