본문 바로가기
Have Done/Attention

[ViTs] Going deep with Image Transformers - LayerScale(2/4)

by 에아오요이가야 2022. 10. 27.

Imports

import io
import typing
from urllib.request import urlopen

import matplotlib.pyplot as plt
import numpy as np
import PIL
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
tf.__version__ ##2.10.0 필수입니다!!!

 

Step 1 : CaiT 에서 사용되는 block들 만들기

1-1. The LayerScale layer

CaiT paper에서 제안된 두가지의 modification중에 하나인 LayerScale를 우선 구현해 보겠습니다.
ViT models에서 depth를 증가시킬 때, Loss Optimization이 불안정하여 수렴하지 않는 문제가 있습니다.

ResNet구조에서 많이 접하셨으리라 생각합니다!

각각의 Transformer block에 residual connection을 연결지어서 information bottleneck 구조를 갖도록 했습니다.
연결된 bottleneck 구조는 모델의 depth가 증가될 때, Loss Optimization이 수렴할 수 있도록 가이드를 제시하는 역할을 합니다.

 
𝑥𝑙=𝑥𝑙+𝑆𝐴(𝜂(𝑥𝑙))xl′=xl+SA(η(xl))
𝑥𝑙+1=𝑥𝑙+𝐹𝐹𝑁(𝜂(𝑥𝑙))

SA 는 self-attention
FFN 는 feed-forward network
𝜂η 는 the LayerNorm operator 를 각각 의미합니다.

요약하자면 다음식과 같이 적을 수 있습니다.

 

𝑥𝑙+1=𝑥𝑙+𝑆𝐴(𝜂(𝑥𝑙))+𝐹𝐹𝑁(𝜂(𝑥𝑙+𝑆𝐴(𝜂(𝑥𝑙))))xl+1=xl+SA(η(xl))+FFN(η(xl+SA(η(xl))))

 

Fixup, ReZero and SkipInit 논문에서 이전보다 낫다고 주장된 구조는 다음과 같습니다, warmup과 layer normalize를 제거함

 

𝑥𝑙=𝑥𝑙+𝛼𝑙𝑆𝐴(𝑥𝑙)xl′=xl+αlSA(xl)
𝑥𝑙+1=𝑥𝑙+𝛼𝑙𝐹𝐹𝑁(𝑥𝑙)xl+1=xl′+αl′FFN(xl′)

하지만 이또한 수렴성의 논란이 있기 때문에 새로운 구조를 소개합니다.

 

LayerScale의 개념은 아래의 식과같이 적을 수 있습니다.

기존의 Residual Vit와 다른점은 diag matrix 포함돼있다는 것 입니다!

무슨 의미 인지가 제일 중요하겠죠? - per-channel multiplication of the vector produced by each residual block , as opposed to a single scalar, Our objective is to group the updates of the weights associated with the same output channel. Formally, LayerScale is a multiplication by a diagonal matrix on output of each residual block.

 

𝑥𝑙=𝑥𝑙+𝑑𝑖𝑎𝑔(𝜆𝑙,1,...,𝜆𝑙,𝑑)×𝑆𝐴(𝜂(𝑥𝑙))xl′=xl+diag(λl,1,...,λl,d)×SA(η(xl))
𝑥𝑙+1=𝑥𝑙+𝑑𝑖𝑎𝑔(𝜆𝑙,1,...,𝜆𝑙,𝑑)×𝐹𝐹𝑁(𝜂(𝑥𝑙))

diag는 diagonal matrix 즉, 대각 행렬을 의미합니다.

정리를 다시 하면 다음과 같은 식이 됩니다.

𝑥𝑙+1=𝑥𝑙+𝑑𝑖𝑎𝑔(𝜆𝑙,1,...,𝜆𝑙,𝑑)×𝑆𝐴(𝜂(𝑥𝑙))+𝑑𝑖𝑎𝑔(𝜆𝑙,1,...,𝜆𝑙,𝑑)×𝐹𝐹𝑁(𝜂(𝑥𝑙+𝑑𝑖𝑎𝑔(𝜆𝑙,1,...,𝜆𝑙,𝑑)×𝑆𝐴(𝜂(𝑥𝑙))))xl+1=xl+diag(λl,1,...,λl,d)×SA(η(xl))+diag(λl,1′,...,λl,d′)×FFN(η(xl+diag(λl,1,...,λl,d)×SA(η(xl))))
 

수식을 통해 알수 있듯, LayerScale는 residual branches들을 control할 수 있도록 합니다.
LayerScale의 learnable parameters은 매우 작은 숫자로 초기화하여 branches들이 identity functions의 역할을 할 수 있도록 합니다.

𝜆λ는 learnable parameters이고 매우 작은 숫자로 초기화 하여 학습합니다.
(depth 18까지는 0.1, depth 24까지는 0.00001, 그 이상의 depth에 대해서는 0.000001)

이 개념은 ActNorm이나 LayerNorm과 유사하지만 residual block의 output에 적용된다는 점이 다릅니다.
그로인한 차이점은 data-dependent한 초기화되기때문에 Batch Norm과 같은 역할이 수행되지만,
학습 과정 내에서 identity function과 비슷하게 만들어 주고network가 additional parameters를 통합하도록 합니다.

The diagonal matrix additionally helps control the contributionsof the individual dimensions of the residual inputs as it is applied on a per-channel basis.

#설명은 주저리주저리 많았지만 코드는 아주 간단합니다!
class LayerScale(layers.Layer):
    """LayerScale as introduced in CaiT: https://arxiv.org/abs/2103.17239. - Going deeper with Image Transformers

    Args:
        init_values (float): LayerScale의 diagonal matrix 초기값.
        projection_dim (int): LayerScale에서 사용되는 projection dimension.
    """

    def __init__(self, init_values: float, projection_dim: int, **kwargs):
        super().__init__(**kwargs)
        #tf.ones((projection_dim,)) -> [projection_dim,1] shape의 1로 이루어진 vector [1,1,...,1]
        self.gamma = tf.Variable(init_values * tf.ones((projection_dim,)))

    def call(self, x, training=False):
        return x * self.gamma

댓글