본문 바로가기
Have Done/Attention

[ViTs] Going deep with Image Transformers - Class Attention(3/4)

by 에아오요이가야 2022. 10. 27.

1-2. Class attention

The vanilla ViT는 image patches와 learnable CLS(classification) token이 서로 어떻게 상호작용하는지 모델링(학습) 하기 위해 self-attention (SA) layers를 사용합니다.

SA : image pathces와 learnable CLS token의 상호작용 학습하는 연산

CaiT의 저자는 the attention layer가 image patches와 CLS tokens들의 간의 연관성(?)을 떨어지도록 제안했습니다.
propose to decouple the attention layers responsible for attending to the image patches and the CLS tokens.

기존에 우리가 ViTs를 classification과 같은 task에 사용할때, 보통 CLS token에 representations들이 포함되도록 한뒤 task-specific heads로 넘겨주는 방식을 사용합니다.

CNN에서 통상적으로 활용되는 global average pooling을 사용하는것에 배치됩니다.

CLS token과 다른 image patches들 간의 상호작용은 일관적이게 self-attention layer를 통해 얻어집니다.
CaiT의 저자가 말했듯 이러한 형태는 꼬임효과(entangled effect)를 불러일으킵니다.

일례로, the self-attention layers는 the image patches를 학습하는 부분입니다.

반대로, the self-attention layers들은 또한, CLS token을 통해 얻어진 정보들을 요약하는 역할도 하기 때문에 객체를 학습하는데 유용합니다.

이 두가지의 꼬임 현상을 풀기 위해 저자는 다음과 같이 제안했습니다.

  • Introduce the CLS token at a later stage in the network.
    • network의 뒷부분에 CLS token을 집어넣음
  • Model the interaction between the CLS token and the representations related to the
    • CLS token과 the representations들의 상호작용이 연관되도록 modeling함

구분된 attention layer들을 통해 전해지는 image patch들 이것을 저자는 Class Attention (CA)라고 명명하였습니다.
아래의 그림이 저자의 idea를 설명하고 있습니다. (논문에서 가져온 사진입니다)

 

 

이는 CLS token을 CA layer에서 query처럼 다루었기 때문에 얻어지는 효과입니다.
CLS token과 image patch는 key로서 다음 layer에 전해집니다.

 

class ClassAttention(layers.Layer):
    """Class attention as proposed in CaiT: https://arxiv.org/abs/2103.17239. - Going deeper with Image Transformers

    Args:
        projection_dim (int): attention에서 사용되는 query, key, value의 projection dimension 
        num_heads      (int): attention heads의 갯수.
        dropout_rate (float): attention scores와 final projected outputs에서 사용될 dropout rate.
    """

    def __init__(
        self, projection_dim: int, num_heads: int, dropout_rate: float, **kwargs
    ):
        super().__init__(**kwargs)
        self.num_heads = num_heads

        head_dim = projection_dim // num_heads
        self.scale = head_dim**-0.5

        self.q = layers.Dense(projection_dim)
        self.k = layers.Dense(projection_dim)
        self.v = layers.Dense(projection_dim)
        self.attn_drop = layers.Dropout(dropout_rate)
        self.proj = layers.Dense(projection_dim)
        self.proj_drop = layers.Dropout(dropout_rate)

    def call(self, x, training=False):
        batch_size, num_patches, num_channels = (
            tf.shape(x)[0],
            tf.shape(x)[1],
            tf.shape(x)[2],
        )

        # Query projection. `cls_token` embeddings이 queries로 사용됩니다.
        q = tf.expand_dims(self.q(x[:, 0]), axis=1)
        q = tf.reshape(q, (batch_size, 1, self.num_heads, num_channels // self.num_heads))  
       
        # Shape: (batch_size, 1, num_heads, dimension_per_head)
        q = tf.transpose(q, perm=[0, 2, 1, 3])
        scale = tf.cast(self.scale, dtype=q.dtype)
        q = q * scale

        # Key projection. Patch embeddings과 cls embedding이 keys로 사용됩니다.
        k = self.k(x)
        k = tf.reshape(k, (batch_size, num_patches, self.num_heads, num_channels // self.num_heads))  
        
        # Shape: (batch_size, num_tokens, num_heads, dimension_per_head)
        k = tf.transpose(k, perm=[0, 2, 1, 3])

        # Value projection. Patch embeddings과 cls embedding이 values로 사용됩니다.
        v = self.v(x)
        v = tf.reshape(v, (batch_size, num_patches, self.num_heads, num_channels // self.num_heads))
        v = tf.transpose(v, perm=[0, 2, 1, 3])

        #cls_token embedding과 patch embeddings의 attention scores 계산하는 부분입니다.
        attn = tf.matmul(q, k, transpose_b=True)
        attn = tf.nn.softmax(attn, axis=-1)
        attn = self.attn_drop(attn, training)

        x_cls = tf.matmul(attn, v)
        x_cls = tf.transpose(x_cls, perm=[0, 2, 1, 3])
        x_cls = tf.reshape(x_cls, (batch_size, 1, num_channels))
        x_cls = self.proj(x_cls)
        x_cls = self.proj_drop(x_cls, training)

        return x_cls, attn

댓글