논문 읽기/AutoEncoder

[논문 리뷰] Contrastive Masked Autoencoders are Stronger Vision Learners [Z.Huang et al.]

Hyuk42 2022. 9. 6. 17:24

1. Introduction

MAE 등의 masked image modeling (MIM)은 그 단순성과 풍부한 representation을 학습할 수 있는 능력 때문에 기존 self-supervised learning method보다 downstream task에서 좋은 성과를 보였고, 그 덕에 최근 self-supervised learning 분야에서 각광받고 있다.

 

MIM 이전에 self-supervised learning에서 가장 주목받던 contrastive learning과는 달리, MIM은 (reconstruction task를 수행하기 위해) 다른 이미지들과의 관계를 모델링 하기 보다는 인풋 이미지의 local relation 을 학습하는 데에 더 초점을 맞춘다. 그 덕에 MIM은 discriminative representation을 학습하는 것에 덜 효과적이라는 단점이 있다.

 

이에 저자들은 다음과 같은 질문을 제기한다:

'MIM method로 학습된 representation을 더 좋게 만들기 위해 contrastive learning을 이용할 수 없을까?'


이 논문 이전에 contrastive learning과 MIM의 learning objective를 결합한 시도가 있었지만, MIM만 사용했을 때 대비 성능 향상 폭이 그렇게 크지 않았다고 한다. (이는 두 방법에서 쓰이는 augmentation, training objective, model 구조 등 상당 부분이 달랐기 때문일 것이라고 추측한다.)

 

이 논문의 저자들은 이러한 점들을 보완하여 Contrastive learning과 MIM을 하나의 framework (=CMAE)에 넣어서 더 좋은 representation을 학습할 수 있도록 했다고 한다. 

 

미리보는 Contribution

  1. Contrastive Masked AutoEncoder (CMAE) 제안
    → 이를 통해 학습된 representation은 local context sensitive feature 보존 + instance dicriminativeness 모델링 가능
  2. Contrastive learning을 MIM에 추가하기 위해 (Masked feature와 공존 가능한) feature decoder 제안
  3. MIM 성능을 끌어올려 새로운 SoTA 되었음.

2. Related Works

  • Contrastive Learning
  • Mask image modeling (MIM)

3. Method

3.1 Framework

Framework of the CMAE

Framework의 전체적인 구조는 siamese network를 채택하였고, Online branch (주황색), target branch (연두색)
두 개의 브랜치로 구성되어있다.

  • Online branch - masked observation reconstruction 하도록 학습
  • Target branch - contrastive learning (instance discrimination) 하도록 학습

 

기존 MIM method와 다른 점은 인풋 이미지에 spatially shfited cropping operation을 수행한다는 것이다.
(그림에서 Pixel-shifted view를 보자.)

 

## Notation

  • I : input image
  • $x_{i}^{s}$ : token sequence
  • $p_{s}^{v}$ : positional embeddings 
  • $z_{s}^{v}$ : embedding features
  • $x^{v}$ : visible tokens

 

Online encoder ($F_{s}$)

  • Visible tokens $x_{i}^{s}$ → token embedding (by linear projection) →  positional embedding $p_{s}^{v}$ 추가
    → 이걸 sequence of transformer blocks에 투입 →  $z_{s}^{v}$ 얻음 (= $x_{i}^{s}$ → $z_{s}^{v}$로 매핑)
    $z_{s}^{v} = F_{s} (x_{s}^{v} + p_{s}^{v})$
  • Vision Transformer (ViT) 구조 채택 (MAE 설정 따른 것)
  • Pre-train 이후에는 downstream task를 위한 image representation 추출로 쓰임 
    (MAE에서 처럼 downstream task에서는 여기만 쓰이는 듯)
  • 과장을 조금 보태서 한 줄로 요약하자면 "MAE의 인코더와 거의 동일하다"

 

Target encoder ($F_{t}$)

  • Online encoder가 discriminative representation을 학습할 수 있도록 contrastive supervision을 제공하도록 도입
  • Contrastive learning에만 사용
  • 온라인 인코더 $F_{s}$와 동일한 아키텍처를 공유함
  • 학습된 representation의 의미적 무결성과 차별성을 유지하기 위해 마스킹 하지 않고 이미지 전체를 입력으로 가져감
    + NLP에서의 토큰과는 다르게 이미지 토큰은  semantice meaning이 모호함
        → 이런 모호함을 피하기 위해 contrastive learning을 할 때 global representation을 사용

             (Input 전체에 대한 representation을 사용해야 contrastive learning도 잘 되기 때문으로 추측)
             : Feature of target encoder에 mean-pooled operation 적용
               i.e. $z_{t} = {1 \over N} \sum_{j=1}^{N} F_{t}(x_{j}^{t})$
               ($x_{j}^{t}$: input token for target encoder,  $z_{t}$ : representation of the input image)
  • Online encoder와는 다르게 exponential moving average(EMA)로 파라미터 업데이트
    : $\theta_{t} \leftarrow \mu\theta_{t} + (1-\mu)\theta_{s}$  (실험에서 $\mu=0.996$)

 

Online decoder

  • latent variable $z_{s}^{v}$과 Mask token feature를 target encoder, original image의 feature space로 매핑하는 것이 목표
    $\rightarrow$ Reconstruction, contrastive learning 모두 하기 위함
  • Input: encoded visible tokens $z_{s}^{v}$, mask token $z_{s}^{m}$
  • MAE와 유사하게, positional embedding이 input token에 추가됨
  • 그러나 단순히 reconstruction만 하는 MAE와는 달리, contrastive learning까지 같이 하기 위해 디코더의 구조가
    MAE와는 조금 다르다. (two branch 구조)
    → pixel decoder, feature decoder
    1. Pixel decoder $g_{p}$
      • masked patch의 pixel을 reconstruct하기 위해 학습 ↔ patch들의 pixel $y_{m}$ 예측
        ↔ $y^{\prime}_{m}$ = $\textit{I} \times g_{p} (z_{s}^{v},z_{s}^{m})$,
             $\textit{I}$ : indicator to only select the prediction corresponding to masked tokens
      • 여러 개의 transformer block들이 쌓인(stacked) 구조
         $\rightarrow$ MAE에서의 decoder와 거의 동일
                                                 
    2. Feature decoder $g_{f}$ 
      • Contrastive learning 하기 위해 target encoder의 output과 online encoder의 output을 맞추기 위해
        online encoder의 output에서 masking된 토큰의 feature를 복구
        $\rightarrow$ Online encoder의 output의 상당수는 masking 되어 있어 그대로 contrastive learning 돌릴 경우 학습이 제대로 되지 않을 것이기 때문
      • Feature decoder는 pixel decoder와 같은 구조 가졌지만 학습 목표가 다르기 때문에 파라미터를 공유하지는 않음
      • Target encoder에서 그랬던 것 처럼 feature decoder의 output에 mean pooling operation을 적용해 whole image representation $y_{s}$로 삼고, 이 feature를 contrastive learning에 이용한다.
        $y_{s} = {1 \over N} \sum g_{f}(z_{s}^{v},z_{s}^{m})$

 

3.2 View augmentation

 

MIM pre-training task는 인풋 이미지의 single view만 이용하지만, contrastive learning은 서로 다르게 augment된 2개의 view를 이용한다. 따라서 MIM과 contrastive learning이 공존할 수 있도록 이 논문에서도 2개의 다른 view를 생성해서 각각 online branch, target branch에 넣는다.

 

Contrastive learning에서 많이 쓰이는 augmentation은 color transfer / spatial transfer가 있는데, MIM의 경우 색을 바꿀경우 성능저하가 발생했으므로 online branch 에서는 color transfer를 이용하지 않았다. (target branch에서는 trivial solution을 방지하기 위해 둘 다 사용했다고 한다.)

 

두 branch에 contrastive learning에서 하듯 두 개의 다른 random crop을 시행했으나 이는 성능 저하라는 부작용이 있었다.
→ 이는 (무작위로 crop한 영역이 서로 멀리 떨어져 있거나 / 의미론적 관련이 거의 없을 때) online encoder / target decoder 의 인풋이 크게 차이나기 때문이라고 추측된다.

 

일반적인 contrastive learning에서 온전한 paired view를 사용하는 것과 달리, MIM에서 입력의 많은 부분을 마스킹하는 작업은 그러한 격차를 증폭시켜 false positive view들을 생성할 수 있다.

결과적으로, 이러한 잘못 정렬된 positivie pair에 대해 contrastive learning을 수행하면 실제로 노이즈가 발생하고 차별적이고 의미 있는 representation의 학습을 방해한다. 이를 방지하기 위해 저자들은 pixel shifting 이라는 약한 augmentation을 제안한다. 이는 다음의 두 단계로 진행된다:


1. Original image에 resized ramdom cropping을 적용해 master image 얻는다.

2. 이러면 두 브랜치는 동일한 master image를 공유하게 되는데, 여기서 master image에 cropping location을
     살짝 이동시킴으로써 각 브랜치에 들어갈 view를 만든다.

 

ex) master image x -> shape: (w+p, h+p, 3) (w,h = width, height of target input size, p = longest shifting range allowed)
[0:w, 0:h, :] -> online branch       /     [$r_{w}:r_{w}+w, r_{h}:r_{h}+h, :]] -> target branch  ($r_{w}, r_{h}$ : indep. random values in range [0,p) )

이후 image masking, color augmentation이 각각 $x_{s}, x_{t}$ 에 적용된다.

 

3.3 Training Objective

Reconstruction loss

 

- Reconstruction task에서 normalized pixel을 target으료 사용하고, Mean Squared Error (MSE)를 loss function으로 삼아 masked patch에 대해서만 original image와의 loss를 계산한다.

 

Contrastive loss

 

- InfoNCE lsss를 사용했다. 

- "projection-prediction" , projection head를 feature decoder, target encoder에 각각 추가
   projection head with target encoder : EMA로 업데이트 

- $y_{s}$ : output of feature decoder-> projection prediction으로 $y_{s}^{p}$으로 변환 (projection을 통한 prediction으로 이해하면 될 것 같다.)

- $z_{t}$ : output of target encoder -> projection head 통과 후 $z_{t}^{p}$

- s : cosine similarity

cosine similarity

- positive pair: 동일한 이미지에서 나온  y, z

- negative pair: batch 내 다른 이미지에서 나온 y, z

 

 

Overall loss

 

- $L = L_{r} + \lambda_{c}L_{c}$

 


4. Experiment

# Pre-training

: MAE에서의 세팅을 이용함 / ImageNet-1k 의 training data로 학습

 

# Encoder structure

: ViT base model

 

 

# ImageNet classification

 

# Transfer Learning - Segmentation, Object Detection

 

Ablation Study

1. 저자들의 Contribution이 유효(?)한가? (= 성능 향상에 이바지 했는가?)

$\rightarrow$ Contrastive learning, pixel shifting augmentation, feature decoder 모두 추가되었을 때 성능 향상이 있었다.

 

2. Overall objective 에서의 가중치

$\rightarrow$ Reconstruction loss와 contrastive loss가 1:1의 가중치로 가중결합 되었을 때 가장 좋은 결과가 있었다.

 

3. Pixel-shifting

$\rightarrow$ Pixel-shifting은 분명 효과가 있었지만 (ablation study #1 참고), 너무 많이 이동시킬 경우 두 branch에 입력되는 view가 너무 달라지기 때문인지 성능이 하락한 것을 볼 수 있었다.


5. Conclusion

  • CMAE 제안 → Contrastive learning 요소 추가해 MIM 성능 향상
  • 이를 위해 input generation, architecture 부분에서 각각 novel design 제안
  • 그 결과 SoTA 달성
  • Contribution 내용

[Arxiv preprint 2022]

 

[Paper]

https://arxiv.org/abs/2207.13532

 

Contrastive Masked Autoencoders are Stronger Vision Learners

Masked image modeling (MIM) has achieved promising results on various vision tasks. However, the limited discriminability of learned representation manifests there is still plenty to go for making a stronger vision learner. Towards this goal, we propose Co

arxiv.org

 

[Code]

공개하지 않았다.... 아쉽다....

 

[통계/머신러닝 모르는 사람의 소심한 비판]

- 사실 비판이라기 보다는 코드 공개를 왜 안했는지... 아쉽....

- Pixel-shifting 과정에서 master image를 만들 때 어떤 크기로 crop을 한 것인지 알려줬으면 좀 더 좋았을 것 같다.