본문 바로가기
DeepLearning/논문 리뷰

[논문 리뷰] Domain-Adversarial Training of Neural Networks

by Wanda 2022. 2. 26.

0. Abstract 

 - 생략

 

1. Introduction

- Domain adaptation(DA)란,  Source domain에서 Target domain으로 mapping해주는 shift가 존재 할때 discriminative classifier나 다른 predictor를 학습하는 것을 말한다.

 

- 이러한 shift를해주는 mapping이 존재하기 때문에 source domain 배운 classifier로 target domain에 적용할 수 있게 된다.

 

- 기존의 paper에서는 domain adaptation이 fixed한 feature에서 일을 수행한다고 말한다. 

(이게 무슨 소리냐면, 우리가 source에서 배우는 feature가 원하는 target domain에서의 변환이 가능하게 하기 위해 고정된 feature를 이용한다는 이야기다. 즉, 만약 다른 target domain이 존재한다면, domain adaptation을 할 수가 없다.)

 

-  해당 논문에서는 이러한 'domain adaptation'영역과 'deep feature learning' 영역을 결합하여 training process 안에서 수행하도록 focus를 맞춘다고 한다.

 

- 이렇게 결합하면서 얻은 feed-forward network는 두 도메인 사이에 존재하는 'shift'에 방해를 받지 않고 target domain에 적용할 수 있다고 한다.

 

* 기존 Domain adaptation에서 나온 이론에 의하면, cross-domain transfer의 알고리즘이

 input으로 들어가는 domain을 식별하지 못할때 가장 좋은 representation 이라고 이야기 한다. 

 

- 해당 논문의 목표는 domain adaptation을 learning representation에 포함시키는 것이고, 이로 인해 final classification decision은 discriminative(label 분류)하고 domain의 변화에 invariant한 features을 base로 할 수 있다.

 

- 이러한 discriminative한 성질과 invariant한 성질을 위해, 이 모델을 통해 얻는 deep feature mapping의 parameters는 label classifier의 loss를 최소화하고, domain classifier의 loss는 최대화 하도록 최적화(optimized)될 것이다. 

 

2. Related work 

 생략 

 

3. Domain Adaptation

 X: input space, Y : {0,1,......,L-1} is the set of L possible labels

 

이러한 X, Y의 조합인 X*Y로, 두개의 다른 distribution인 Source Domain, Target Domain을 생성한다.

 

 

- Unsupervised domain adaptation learning을 통해 label source sample S를, 그리고 unlabeled target sample T를 생성한다.  

- 본 모델의 궁극적인 목표는 target risk를 최소화 시켜주는 classifier η: X->Y를 생성하는 것이다. 

3.1 Domain Divergence 

-많은 접근법들은 target error(risk)를 source error와 target distribution 사이의 거리에 대한 개념으로 bound를 설정한다.

 

- 이러한 생각에는 가정을 기본으로 하는데, 바로 다음과 같다

  'source risk는 target risk와 source risk 간의 distribution이 유사할 때, 좋은 indicator가 될수 있다.'

 

- 이러한 '거리'에 대한 개념을 H-divergence란 이론을 통해 설명한다. 

 

* H-divergence 

- 여기서 'sup'이란 상한(supremum)의 줄임말로, 상계 중에서 가장 작은 값을 말한다. 

 

- 여기서 H란, source domain과 target domain에 동시에 적용되는 classifer η의 모임으로, 본 논문에서의 편의를 위해서 각각의 η은 [0,1]로 이루어진 binary classifer 이다. 

 

- 본 논문에서의 목적은 target risk를 최소화는 classifier η을 찾는 것을 목표로 한다. 즉, 본 논문에서는, H-divergence의 값이 작을수록 좋은 것이다. 

 

- 만약, 하나의 classifier η가 source domain에서 1로 분류할 확률이 1(100%)이고, target domain에서 1로 분류할 확률이 1(100%)라면, 이는 어떤 domain에 상관 없이 다 1로 분류하게 되므로, source domain과 target domain의 구분을 하지 못한다는 의미이다. 즉, H의 값이 2(1-1)  = 0이 된다. 

 

- 만약, classiier η가 source domain에서 1로 분류할 확률이 1(100%)이고, target domain에서 1로 분류할 확률이 0(0%)라면,  classifier가 source domain과 target domain을 구분할 수 있게 되는 것이다. 즉, H의 값이 2(1-0) = 2가 된다.

 

-H의 class가 symmetric(대칭)이라면, 좀더 경험적인 H-divergence를 이용할 수 있다. 

 

* Empirical H-divergence

- 달라진 점은  source domain에서 classifer가 1일 확률 대신 1- (source domain에서 0일 확률)로 바뀐 것이고 이는 H의 class가 대칭적이기 때문에 이용할 수 있다. (Q. H의 class가 대칭적이라는 말과, 1-D가 어떻게 수식적으로 관련이 있는지 모르겠음) 

 

- 만약, source domain에서 0으로 분류할 확률이 1/2( 1로 분류할 확률이 1/2), target domain에서 1로 분류할 확률이 1/2(0으로 분류할 확률이 1/2)라면, 이는 source domain과 target domain의 구분을 잘 하지 못하는 것으로, H = 2(1-(1/2+1/2)) = 0 이 된다

 

- 만약, source domain이 0로 분류할 확률이 0( 1으로 분류할 확률이 1), target domain이 1로 분류할 확률이 0( 0으로 분류할 확률이 1) 이라면, 이는 target domian과 source domain을 잘 할 수 있게 되는 것으로, H=2(1-(0+0)) = 2가 된다.

 

3.2 Proxy Distance

 

- 그러나, 이렇게 만들어진 empirical H-divergence는 일반적으로 정확하게 계산하기 어렵다.

 

- 이는 source domain과 target domain examples를 정확하게 구분하기 어렵기 때문에 발생하는 문제이다. 이를 해결하기위해, 본 논문에서는 source domain과 target domain의 expample를 근사하는 방법을 이용한다. 

 

- 이 방법을 이용하기 전에, 본 논문에서는 새로운 dataset을 만든다. 

 

* new dataset

여기서 source sample의 example들은 0으로, targe sample의 example은 1로 label된다. 이렇게 새로운 datatset U를 이용하여  train된 classifier의 risk는 Equation (1)의 "min" part와 유사하게 된다. 기존의 문제인 source와 target example를 분류하는 문제를 해결해주는 generalization error를 Ɛ로 대체하면서,  H-divergence는 다음과 같이 근사할 수 있게 된다.  

- 이를 "Proxy A-distance(PAD)" 라고 부른다. 

 

- 이러한 A-distance는 다음과 같의 정의될 수 있다.

- 여기서 A는 X의 subset이다. 이 A-distance와 H-divegence Definition 1 는 동일하다.  

 

3.3 Generalization Bound on the Target Risk

- Ben-David의 work을 통해서 H-divergence를 이것의 empirical estimate(경험적 추정)과 복잡한 상수로 upper bounded할 수 있다고 한다.

 

- 여기서 Constant complexity는  VC dimension H와 S와 T의 sample size에 의존한다. 이것과 source risk를 결합하면서, 본 논문에서는 다음과 같은 이론을 소개한다. 

 

 

 

4. Domain-Adversarial Neural Networks(DANN)

 

4.1 Example Case with a Shallow Neural Network 

생략

 

4.2 Generalization to Arbitrary Architectures

 

- 위의 정의들을 통해 prediction loss와 domain loss를 구할 수 있다. 

 

 

- 이것들의 loss로 이루어진 식 cost function은 다음과 같다.

- 본 논문에서의 목표는 위의 cost function을 optimizing하는 것인데, 이는 saddle point를 찾음으로써 가능하게 된다. 

 

- (11), (12)에서 정의된 saddle point는 아래의 식의 gradient updates에서 stationary point(미분했을때 0이 되는 값)과 같은 것을 발견할 수 있다. 

 

 

 

- 위의 Equations(13-15)의 updates는 stochastic gradient descent(SGD) updates와 굉장히 유사한데, 다만 다른 점은 (13)에서, class와 domain predictior의 gradients가 sum하는 것이 아닌 subtract한다. 

 

- 운이 좋게도, 위와 같은 감소하는 형태는 'gradient reversal layer(GRL)'을 통해서 수행될 수 있다.

 

* Gradient reversal layer(GRL)

 

  1) parameter가 존재하지 않는다 

  2) forward propagation에서는, indentitiy transformation(항등 변환)역활을 한다

  3) Backward propagation에서는, 나온 gradient 값에 -1를 곱해준다.

 

- 이러한 GRL은 feature extractor과 domain classifier 사이에 넣어준다.

 

- 수학적으로 표현하기 위해, 해당 논문에서는 이러한 GRL를 "pseudo-function" R(x)를 도입한다.

 

-이를 통해 완성되는 식은 다음과 같다. 

- 다음은 전체적인 모델의 구조를 나타낸 것이다. 

 

 

5. Experiments

  - 해당 논문에서 나온 실험 중, 하나의 experiment를 소개하겠다.

- 다음의 데이터은 'inter-twinning moons' toy problem을 나타낸 사진이다. 데이터셋에 대한 설명을 하자면, source sample의 경우 빨간색과 초록색(라벨링이 된 것임)을 나타낸것이고, target sample(라벨링이 되지 않음)은 검은색 점들을 나타낸 것이다. target sample은 source sample에 35도 회전을 해 준 것이다.

 

- 첫번째 열은 classification 결과를 나타낸 것이다. standard NN를 이용할 경우 source sample에 대해서는 classification이 잘 일어났지만 target sample 중 하나인 D 지점에서는 제대로 classification이 일어나지 않은 것을 확인 할 수 있다. 아래 DANN의 경우는 target sample에서도 classification이 잘 일어난 것을 확인할 수 가 있다.

 

- 두번째 열은 PCA 분석을 나타낸 것이다.  standard NN의 경우, PCA 분석이 source sample에 맞춰져 있어서 이를 target sample에 적용할 경우 제대로 PCA 분석이 일어나지 않은 것을 확인 할 수 있다. DANN에서는 위의 PCA보다 좀더 선형적으로 분석이 이루어져, target sample에 대해서도 제대로 분석이 일어난 것을 확인 할 수가 있다. 

 

- 세번째 열은 모델이 얼마나 domain classification을 못했는지를 나타낸 것이다(domain adaptation을 잘 하려면, 우리는 domain을 인식하지 못해야 한다!). 둘 다 제대로 을 하지는 못하지만, DANN모델이 더 domain을 인지하지 못하는 것을 확인할 수가 있다.