새소식

AI Research Topic/Implementation

[Implementation] Vision Transformer (ViT) 구현

  • -

 

 

오늘은 저번에 리뷰한 vision transformer 를 구현해 볼 생각이다.

 

ViT 에 대해 잘 모른다면 아래 링크를 통해 주요 concept 을 숙지하고 오자.

 

https://hobby-is-self-improvement.tistory.com/47

 

[Paper Review] An Image is Worth 16X16 Words: Transformers for Image Recognition at Scale

GitHub: https://github.com/google-research/vision_transformer Paper: https://arxiv.org/pdf/2010.11929.pdf 오늘 리뷰할 논문은 ICLR 2021에 소개된 Vision Transformer (ViT) 논문이다. 2012년에 발표된 AlexNet 을 기점으로 image classific

hobby-is-self-improvement.tistory.com

 

 

그렇게나 성능 좋다고 유명한 CNN 기반 모델들을 압살한 만큼 추후에 transformer 기반의 변형된 모델들이 우후죽순 나올꺼라 생각된다. 그래서 vision 분야에서 가장 근간이되는 ViT 의 세부적인 기능과 로직을 이해해야 transformer 변형 모델들을 수월하게 이해하고 활용할꺼라 판단해 구현하기로 맘 먹었다.

 

요즘 워낙 성능 좋다고 알려진 모델들은 오픈소스에서 금방금방 추가되다보니 모듈 단위로 가져다 사용하는 것에 익숙해졌다. 그래서 처음부터 구현할 일이 거의 없는데 이번 경험을 통해 모델을 바라보는 시야가 넓어지고 구현 속도가 빨라지길 희망해본다.

 

오늘 구현하면서 고려할 점은 다음과 같다.

  • image 를 patch 로 변환시 einops 같은 텐서 차원을 쉽게 조작하는 라이브러리 사용 지양 -> 연습이 목적인 만큼 최대한 직접 구현
  • 모델 구현에 중점 -> 데이터셋 처리, 학습 프레임워크 작성 등은 제외
  • 최대한 직관적으로 작성 -> 최적화는 신경쓰지 말고 최대한 직관적이고 이해하기 쉽게 작성

 

 

구현은 다음과 같은 단계로 진행될꺼다.

  1. Patchifying
  2. Linear mapping
  3. Adding classification token
  4. Adding positional embedding
  5. Build transformer encoder
  6. Classification head

 


 

첫 번째 단계는 image 를 patch 단위로 분할하는 과정이다.

 

Patch 로 분할하기에 앞서 image 를 불러와보자. 이미지로는 강아지를 사용할 생각이다. 이유는 귀여우니까.

 

image = Image.open('images/dog.jpg')

fig = plt.figure()
plt.imshow(image)
plt.show()

 

귀여운 강쥐

 

 

귀여운 강쥐 사진이 잘 불러와졌다. 이제 균등한 patch 크기로 분할할 수 있도록 정사각형 크기(224x224) 로 이미지를 전처리하고 여러 이미지를 한 번에 처리할 수 있도록 배치 차원을 추가해주자.

 

 

transformer = Compose([Resize((224, 224)), ToTensor()])

image = transformer(image)
images = image.unsqueeze(0) # batch 차원 추가

 

 

마지막으로 아무 기능도 없는 nn.Module 을 정의해서 단계별로 이 클래스에 내용을 채워나가 보자.

 

class ViT(nn.Module):
    def __init__(self):
      super(ViT, self).__init__()

 

 

이제 모든 준비가 끝났으니 본격적으로 patchifying 을 구현하자.

 

ViT 논문을 보면 patchifying 과정을 다음과 같이 설명했다.

 

 

요약하자면, RGB 이미지 한 장이 주어졌을 때, 원본 이미지 resolution 을 patch resolution 으로 나누어 패치 개수를 구하고 각 patch 의 height x width x channel 을 통해 flattened 2차원 이미지 시퀀스로 변환할 수 있다는 의미이다.

 

해당 기능은 다음과 같이 구현할 수 있다.

 

def patchify(images, patch_size):
    batch_size, channels, height, width = images.shape

    # calculate the number of patches in both dimensions
    num_patches_h = height // patch_size
    num_patches_w = width // patch_size

    # Reshape the input image into patches
    patches = images.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
    patches = patches.contiguous().reshape(batch_size, channels, num_patches_h * num_patches_w, patch_size, patch_size)
    patches = patches.permute(0, 2, 3, 4, 1)

    # flatten into 2D images
    # 시각화를 원할 땐 이 라인을 주석 처리 !!!
    patches = patches.contiguous().view(batch_size, num_patches_h * num_patches_w, patch_size * patch_size * channels)

    return patches # torch.Size([1, 196, 768])

 

이미지가 정사각형이 아닐 때를 대비해 패치 개수를 height 와 width 로 각각 구분해 계산한다.

 

자 이제 이미지가 패치화가 잘 되었는지 눈으로 직접 확인해보자!

 

내가 인공지능 분야중 vision 분야를 좋아하는 이유도 바로 시각적으로 모든 과정을 확인할 수 있기 때문이다.

 

패치를 시각화할 수 있도록 간단한 함수를 만들어봤다.

 

def plot_patches(patches, patch_size):
    num_patches = patches.shape[0]
    num_rows = int(num_patches ** 0.5)
    num_cols = num_patches // num_rows
    if num_patches % num_rows != 0:
        num_cols += 1

    fig, axs = plt.subplots(num_rows, num_cols, figsize=(10, 10))

    for i in range(num_patches):
        row = i // num_cols
        col = i % num_cols

        patch = patches[i]
        patch = patch.reshape(patch_size, patch_size, -1)

        axs[row, col].imshow(patch)
        axs[row, col].axis('off')

    plt.tight_layout()
    plt.show()

 

 

현재 patches 는 배치 차원이 추가되어있기 때문에 차원을 제거한 이미지 형태인 channel x height x width 로 변환 후 시각화해보자.

 

patches = patchify(images, patch_size=16)
plot_patches(patches[0], patch_size=16)

 

 

et voilà!!

 

 

아주 흡족스러운 결과가 나왔다. 224x224x3 이미지가 16x16 크기 196개의 patch 로 변환된걸 확인할 수 있다.

 

 


 

이제 flattened patches 가 준비되었으니 이 패치들을 linear mapping 을 통해 변환해보자.

 

변환 시 좀 더 깔끔한 구현을 위해 PatchEmbedding 모듈을 만들어보자.

 

class PatchEmbedding(nn.Module):
    def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768):
        super().__init__()

        self.patch_size = patch_size
        # 2) Linear mapping
        nn.Linear(patch_size * patch_size * in_channels, emb_size)

    def forward(self, x):
        # 1) patchifying
        x = patchify(x, self.patch_size)

        return x

 

이 예제에서는 768 차원으로 매핑 값을 사용했지만, 원칙적으로 어떤 숫자든 사용할 수 있다.

 

 


Adding classification token

 

이번 단계에서는 classification token 을 추가해주자. 이 토큰은 randomly initialized 한 torch.Parameter 이며, forward 메소드에서는 batch 크기만큼 복사해서 torch.cat 을 이용해 투영된 패치들 앞에 추가한다.

 

class PatchEmbedding(nn.Module):
    def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768):
        super().__init__()

        self.patch_size = patch_size

        # 2) Linear mapping
        self.linear_mapper = nn.Linear(patch_size * patch_size * in_channels, emb_size)

        # 3) adding classification token
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))

    def forward(self, x):
        b, _, _, _ = x.shape

        # 1) patchifying
        x = patchify(x, self.patch_size)

        x = self.linear_mapper(x)

        cls_tokens = repeat(self.cls_token, '() n e -> b n e', b=b)
        # prepend the cls token to patch embeddings
        x = torch.cat([cls_tokens, x], dim=1)

        return x

 

 


Adding positional embeddings

 

현재까지의 모델은 입력 이미지 내에서 각 패치의 위치 정보를 가지고 있지 않다. Attention is All You Need 논문에서는 사인과 코사인 파형을 이용해 위치 인코딩을 생성했지만, ViT 에서는 이를 학습으로 해결한다. 따라서, 패치의 개수 + 1 (cls token) 크기의 learnable parameter 를 생성해 투영된 패치에 위치 정보를 추가한다.

 

class PatchEmbedding(nn.Module):
    def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768, image_size = 224):
        super().__init__()

        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size) **2

        # 2) Linear mapping
        self.linear_mapper = nn.Linear(patch_size * patch_size * in_channels, emb_size)

        # 3) adding classification token
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))

        # 4) adding positional embedding
        self.positions = nn.Parameter(torch.randn(self.num_patches + 1, emb_size))

    def forward(self, x):
        b, _, _, _ = x.shape

        # 1) patchifying
        x = patchify(x, self.patch_size)

        x = self.linear_mapper(x)

        cls_tokens = repeat(self.cls_token, '() n e -> b n e', b=b)
        x = torch.cat([cls_tokens, x], dim=1)

        x += self.positions

        return x

 

 


Build transformer encoder

 

이제 ViT 에서 가장 어렵고 까다로운 transformer encoder 부분이다. 자세한 구조는 아래 그림을 참고하자.

 

처음 image 를 patch embedding 으로 변환하고 해당 값을 encoder 에 입력하는 구조이다. 인코더는  layer normalization, multi-head attention, residual connections, MLP 로 구성되어있다. 이제 하나씩 차근차근 구현해보자.

 

 

Attention

Attention 은 query, key, value 를 입력받는다. 이 때, query 와 key 사이에 dot product 연산을 통해 쿼리와 키의 유사성을 나타내는 attention score 를 산출하고 이 점수를 value 에 적용하여 중요한 정보를 강조한다.

 

 

class MultiHeadAttention(nn.Module):
    def __init__(self, emb_size: int = 768, num_heads: int = 8, dropout: float = 0):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        self.qkv = nn.Linear(emb_size, emb_size * 3)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def forward(self, x, mask):
        qkv = rearrange(self.qkv(x), "b n (h d qkv) -> (qkv) b h n d", h=self.num_heads, qkv=3)
        queries, keys, values = qkv[0], qkv[1], qkv[2]
        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)  # batch, num_heads, query_len, key_len

        if mask is not None:
            fill_value = torch.finfo(torch.float32).min
            energy.mask_fill(~mask, fill_value)

        scaling = self.emb_size ** (1 / 2)

        att = F.softmax(energy, dim=-1) / scaling
        att = self.att_drop(att)
        out = torch.einsum('bhal, bhlv -> bhav ', att, values)
        out = rearrange(out, "b h n d -> b n (h d)")
        out = self.projection(out)

        return out

 

 

Residuals

 

class ResidualConnection(nn.Module):
    def __init__(self, fn):
        super.__init__()
        self.fn = fn

    def forward(self, x):
        residual = x
        x = self.fn(x)
        x += residual
        return x

 

 

MLP

 

class FeedForwardBlock(nn.Sequential):
    def __init__(self, emb_size: int, expansion: int = 4, drop_p: float = 0.):
        super().__init__(
            nn.Linear(emb_size, expansion * emb_size),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(expansion * emb_size, emb_size),
        )

 

 

마지막으로 여태 구현한 기능들을 transformer encoder block 으로 뭉쳐보자.

 

class TransformerEncoderBlock(nn.Sequential):
    def __init__(self,
                 emb_size: int = 768,
                 drop_p: float = 0.,
                 forward_expansion: int = 4,
                 forward_drop_p: float = 0.,
                 ** kwargs):
        super().__init__(
            ResidualConnection(nn.Sequential(
                nn.LayerNorm(emb_size),
                MultiHeadAttention(emb_size, **kwargs),
                nn.Dropout(drop_p)
            )),
            ResidualConnection(nn.Sequential(
                nn.LayerNorm(emb_size),
                FeedForwardBlock(
                    emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
                nn.Dropout(drop_p)
            )
            ))

 

 

ViT에서는 NLP 의 원조인 Transformer 구조를 최대한 활용하기 때문에, 여러 개의 인코더 블록을 활용한다. 원하는 깊이에 맞게 모델에 인코더 블록을 추가할 수 있도록 하자.

 

class TransformerEncoder(nn.Sequential):
    def __init__(self, depth: int = 12, **kwargs):
        super().__init__(*[TransformerEncoderBlock(**kwargs) for _ in range(depth)])

 

 


Classification head

 

ViT 모델 구현의 마지막 단계인 classification head 이다. 이 블록은 전체 시퀀스에서 softmax 를 통해 각 클래스에 대한 확률 분포를 구하는 fully-connected layer 이다. 이 때 N개의 시퀀스 중 cls_token (첫 번째 토큰) 만 추출하여 입력한다.

 

class ClassificationHead(nn.Sequential):
    def __init__(self, emb_size: int = 768, n_classes: int = 1000):
        super().__init__(
            Reduce('b n e -> b e', reduction='mean'),
            nn.LayerNorm(emb_size), 
            nn.Linear(emb_size, n_classes))

 

 


마무리

 

이제 ViT 클래스에 모든 기능을 넣고 torchsummary 를 통해 최종 모델 크기를 확인해보자.

 

class ViT(nn.Sequential):
    def __init__(self,
                in_channels: int = 3,
                patch_size: int = 16,
                emb_size: int = 768,
                img_size: int = 224,
                depth: int = 12,
                n_classes: int = 1000,
                **kwargs):
        super().__init__(
            PatchEmbedding(in_channels, patch_size, emb_size, img_size),
            TransformerEncoder(depth, emb_size=emb_size, **kwargs),
            ClassificationHead(emb_size, n_classes)
        )

 

print(summary(ViT(), (3,224,224), device='cpu'))

 

 

Contents

포스팅 주소를 복사했습니다

이 글이 도움이 되었다면 공감 부탁드립니다.