(10.2 ~ 10.11) - colab 에 작성한 코드 옮겨둘 것.
cf.
https://github.com/FrancescoSaverioZuppichini/ViT
https://www.youtube.com/watch?v=TrdevFK_am4
https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py
cf.
https://89douner.tistory.com/339
1. Representation Learning 이란?
안녕하세요. 이번글에서는 representation learning이라는 개념에 대해서 설명하려고 합니다. 개인적으로 2021년 동안 논문을 살펴보면서 가장 눈에 많이 띄었던 용어가 representation learning 이었습니다.
89douner.tistory.com
https://blog.si-analytics.ai/21
라벨 스무딩(Label smoothing), When Does Label Smoothing Help?
$ \newcommand{\infdiv}{D\infdivx} \newcommand{\comz}{\mathcal{Z}} \newcommand{\vec}{\boldsymbol} $ 딥 러닝의 신뢰도를 개선하기 위한 모델 보정(calibration) 기법 소개 최근 다양한 분야에서 각광 받는 딥 러닝은 성능 면
blog.si-analytics.ai
3.1 Vition Transformer
Inductive bias
: inductvie bias 란 학습 알고리즘이 특정 패턴이나 일반화를 추론하는 데 사용하는 사전 지식이나 가정입니다.
예를 들어서 CNN 은 이미지 처리에 특화된 모델로 필터가 이미지의 작은부분들을 처리합니다. 또 cnn 계층적 구조를 가지는데 low level 부터 high level 까지 점진적으로 학습합니다.
이에반해 Vision Transformer 는 CNN 보다 더 적은 inductive bias 를 가지고 있습니다. (Robust 합니다)
하지만 이로인해서 더 많은 데이터를 요구하게 됩니다.
조금더 자세한 설명은 아래에 정리해 두었습니다.
locality : 이미지를 구성하는 특징들은, 이미지 전체가 아닌 일부 지역들에 근접한 픽셀들로만 구성되고, 근접한 픽셀들끼리만 종속성을 갖는다는 가정
translation equivariance :
https://robot-vision-develop-story.tistory.com/29
-
- Locality (국소성):
- 설명: CNN의 필터는 국소적인 영역에서 작동합니다. 즉, 작은 영역(커널)에서만 연산을 수행하여 이미지의 특정 부분을 집중적으로 처리합니다.
- Two-Dimensional Neighborhood Structure (2차원 이웃 구조):
- 설명: 이미지가 2차원 배열(행렬)로 표현되므로, CNN의 필터는 2차원 구조를 그대로 유지하며 인접한 픽셀 간의 관계를 학습합니다.
- Translation Equivariance (평행 이동 등가):
- 설명: 이미지의 한 부분이 이동해도, 필터가 동일한 특성을 추출할 수 있는 능력. 즉, 이미지의 물체가 이동하더라도 그 물체를 인식할 수 있습니다.
3.3. Vision Transformer (ViT)와 Inductive Bias
- ViT는 CNN과는 다른 구조를 가지고 있으며, CNN의 inductive bias가 상대적으로 약하게 작용
- ViT는 데이터 주도(data-driven) 학습 방식을 통해 이러한 약한 편향을 극복하며, 보다 범용적인(task-agnostic) 구조를 제공
- 그러나 다음과 같은 방법을 통해 일부 inductive bias를 도입합니다:
- MLP Layers와 Locality 및 Translation Equivariance:
- 설명:
- ViT의 MLP 레이어는 각 패치의 특성을 학습하며, self-attention 메커니즘을 통해 국소적 특성 및 위치 변화를 학습
- 이는 CNN에서와 같은 강한 국소성을 제공하지는 않음
- 2D Neighborhood Structure와 입력 패치로 자르는 과정:
- 설명:
- 이미지를 패치로 자르는 과정은 CNN의 필터가 2차원 이웃 구조를 학습하는 방식과 유사하게 작용
- Position Embedding과 Fine-Tuning:
- 설명:
- 위치 임베딩을 통해 패치의 위치 정보를 인코딩하여 이미지의 구조적 정보를 유지
- 이는 학습 후 미세 조정을 통해 더욱 최적화될 수 있음
3.2 Fine-Tunining and High Resolution
ViT 를 매우 큰 데이터 셋에 대해서 pre-train 시키고 더 작은 데이터 셋에서 downstream tasks 에 대해서 fine-tuning 합니다. 이때 BERT 류의 LM 과 유사하게 pre-train 된 prediction head 를 제거하고 zero-initialized 된 DxK feed forward layer 를 붙입니다. ( 여기서 K 는 downstream classes).
높은 해상도의 이미지로 학습하고 있을때, patch 의 크기는 동일하게 유지합니다. patch 크기는 동일한데 높은 해상도 이므로 sequence length 가 자연스럽게 길어지게 됩니다. 그렇게 되면 기존의 위치 임베딩 벡터가 더이상 유효하지 않게 됩니다.
예시:
pre-train
이미지 해상도 224x224
패치 16x16
시퀀스 길이 = (224/16) = 14*14 = 196
위치 임베딩 : 196
fine-tune시
이미지 해상도 448ㅌ 448
패치 16x16
시퀀스 길이 = (448/16) = 28*28 = 784
위치 임베딩 : 784 ( 기존 위치 임베딩이 쓸모 없게 됨)
따라서 본 논문에서는 pre-train 된 poisitonal embedding 을 2D interpoltion 합니다.
그리고 이런 resolution 조정과 patch 추출이 ViT 의 2d img 구조에 대한 Inductive bias 를 수동으로 바꾸는 유일한 지점이라고 합니다.
그냥 간단하게 말해서 크기를 키웠으니 그에 비례한 값을 찾아주는거라고 보면 된다. 만약에 이부분을 구현하게 된다해도 그렇게 어려울 것 같지는 않다.
cf.
https://daebaq27.tistory.com/108
https://darkpgmr.tistory.com/117
4. EXPERIMENTS
https://baekyeongmin.github.io/paper-review/vision-transformer-review/
이번 논문의 내 주요 포커스는 experiments 부분이다. vit 모델부분은 transformer 를 구현에 비해서 상당히 간단할 것으로 예상하고 있다. 본 논문에서 더 얻을수 있는게 있다면 어떤식으로 실험을 했는지, 무슨 데이터 셋을 사용했는지 부터 디테일한 실험적인 부분까지 알아야 나도 나중에 써먹을 수 있다고 생각이 들었다.
또 science 보다 art 에 가깝다는 말이 나오는 것으로 미루어보아, empirical 하게 좋은 방법을 써서 성능을 내거나 많은 실험적인 디테일이 들어가 있을 것으로 짐작이 되고 나 역시 그런 부분을 키워야 연구자로써 성장할 수 있지 않을까 짐작이 들어 이런부분에서 인사이트를 조금이라도 넓혀 보고자 한다.
CNN(ResNet 계열의 기존 SoTA), Vision Transformer(ViT), hybrid 3가지 모델 구조를 비교한다. 각 모델별 사전학습에서 데이터 요구량을 이해하기 위해 다양한 크기의 데이터로 사전학습 후 평가 진행한다.
4.1 Set up
Dataset.
pre-train dataset
- ImageNet - 1k classes & 1.3M images
- ImageNet-21k - 21k classes & 14M images
- JFT - 18k classes & 303M high resolution images
pre-train dataset 과 downstream 테스크의 테스트 데이터셋이 겹치지 않도록 분리함.
Fine-tuning dataset
- ImageNet: 원래 validation label + clean-up ReaL label
- CIFAR-10/100
- Oxford-IIIT Pets
- Oxford Flowers-102
19-task VTAB classification suite : 테스크 별로 적은 양의 데이터(1000개의 학습셋)가 존재하는 다양한 테스크, 테스크는 크게 3가지 분류로 나눈다.(Transfer learning 때 사용한다)
1. Natural: 위의 CIFAR, Pets와 같이 일반적인 데이터
2.Specialized: medical, satellite 이미지
3.Structured: 위치와 같이 기하학적인 이해가 필요한 테스크
Training & Fine-tuning
-Training
- 모든 모델은 Adam 으로 훈련한다.Adam의 하이퍼파라미터는 β1 = 0.9, β2 = 0.999로 설정합니다. ( 가장 흔히 쓰이는 값)
- batch size : 4096
- weight decay : 0.1 ( 매우 높은 수치-> 이는 곧 transfrer learning 에 더 유리하게 작용함 / L2 penalty와 동일함 loss 에 penalty 를 주는방식 )
+ Appendix D.1 을 보면 나와있듯이 일반적인 경우와 다르게 ResNet 에서는 SGD 를 더 많이 쓰지만 본논문의 실험에서는 adam 이 성능이 더 좋았다고 합니다.
* Appendix B.1
1. We use a linear learning rate warmup and decay, see Appendix B.1 for details. -> Table 3 의 LR decay 를 보면 정확한 수치를 알 수 있습니다. (https://velog.io/@melan/Learning-rate-warm-up-w-decay / https://eagle705.github.io/Learning-rate-warmup-scheduling/)
2. ImageNet 에서 처음모델을 학습할때 strong regularization 을 사용하였다.
3. Dropout 은 모든 dense layer 마다 사용하였다 ( 예외 : qkv- projection , positional embedding patch)
4. 224x224 resolution(해상도) 로 모두 학습시켰다.
* 여기서 다시 볼 점들 : 1. Adam vs SGD 2. Weight decay 3. linear learning rate warmup and decay 4. LR decay
5.
1.
**높은 가중치 감쇠 (Weight Decay)**를 사용하는 것의 장점은 다음과 같습니다:
1. 과적합 방지
- 과적합 방지 (Regularization):
- Weight Decay는 모델이 훈련 데이터에 과적합되는 것을 방지하는 정규화 기법 중 하나입니다. 가중치 값을 줄임으로써 모델이 훈련 데이터에 너무 잘 맞추려는 경향을 줄이고, 더 일반화된 성능을 갖도록 합니다.
2. 가중치의 안정성
- 가중치의 안정성 (Weight Stability):
- 높은 가중치 감쇠는 모델의 가중치 값이 지나치게 커지는 것을 막아줍니다. 이는 특히 딥러닝 모델에서 중요한데, 가중치 값이 너무 커지면 학습 과정이 불안정해질 수 있습니다.
3. 학습 과정의 안정성
- 학습 과정의 안정성 (Training Stability):
- 가중치가 작을수록 그레이디언트 폭발(Gradient Explosion) 현상을 피할 수 있습니다. 이는 학습 과정이 더 안정적이고, 더 예측 가능하게 만듭니다.
4. 모델의 단순화
- 모델의 단순화 (Model Simplification):
- Weight Decay는 모델의 복잡성을 줄이는 데 기여할 수 있습니다. 더 작은 가중치는 더 단순한 모델 구조를 의미하며, 이는 해석 가능성을 높이고, 예측을 더 투명하게 만듭니다.
5. 전이 학습에서의 유리함
- 전이 학습에서의 유리함 (Advantages in Transfer Learning):
- 높은 가중치 감쇠는 모델이 사전 학습된 데이터셋에서 새로운 데이터셋으로 전이될 때 더 나은 성능을 발휘할 수 있게 도와줍니다. 이는 새로운 데이터셋에서 과적합을 피하면서도 사전 학습된 지식을 유지할 수 있게 합니다.
요약
높은 가중치 감쇠 (Weight Decay)를 사용하는 것은 모델의 일반화 성능을 높이고, 학습 과정의 안정성을 강화하며, 과적합을 방지하는 데 매우 유용합니다. 이를 통해 더 단순하고 해석 가능한 모델을 만들 수 있으며, 특히 전이 학습 상황에서 유리한 결과를 얻을 수 있습니다.
Q -그렇다면 왜 transfer learning 에서 유리한가?
Weight Decay가 전이 학습에서 유리한 이유를 좀 더 자세히 설명해 보겠습니다.
과적합 방지
Weight Decay는 모델의 가중치 크기를 제한하여 과적합을 방지하는 데 도움이 됩니다. 과적합은 모델이 훈련 데이터에 너무 치중하여 새로운 데이터에 대해 일반화 능력이 떨어지는 현상입니다. Weight Decay를 사용하면 모델이 복잡해지는 것을 억제하여 과적합을 줄일 수 있습니다.
가중치 값 안정화
Weight Decay는 학습 과정에서 가중치가 지나치게 커지는 것을 방지합니다. 큰 가중치는 모델을 불안정하게 만들고, 작은 변동에도 예측이 크게 변할 수 있습니다. Weight Decay는 가중치 값을 안정화하여 모델이 보다 일관된 예측을 할 수 있게 합니다.
전이 학습에서의 효과
전이 학습(Transfer Learning)은 사전 학습된 모델을 새로운 데이터셋에 맞게 미세 조정하는 과정입니다. 이때, Weight Decay는 사전 학습된 모델의 가중치를 유지하면서도 새로운 데이터셋에 맞게 조정할 수 있게 도와줍니다. 이는 모델이 사전 학습된 지식을 잃지 않으면서도 새로운 데이터셋에 적응할 수 있게 합니다.
예를 들어, 사전 학습된 모델이 w1,w2,...,wnw_1, w_2, ..., w_n이라는 가중치를 가지고 있다고 합시다. 전이 학습 과정에서 Weight Decay를 사용하면, 모델은 새로운 데이터셋에서 가중치가 과도하게 변하는 것을 방지하고, 기존의 유용한 특성을 유지할 수 있습니다.
결론
Weight Decay는 과적합을 방지하고, 가중치 값을 안정화하며, 전이 학습 과정에서 모델이 사전 학습된 지식을 유지하면서도 새로운 데이터셋에 적응할 수 있게 도와줍니다. 이로 인해 전이 학습의 성능이 향상되고, 더 일반화된 모델을 만들 수 있게 됩니다.
4. Linear Learning Rate Warmup and Decay
개념 설명
- Learning Rate Warmup:
- 학습 초기 단계에서 learning rate(학습률)를 천천히 증가시키는 방법입니다. 초기에는 작은 학습률을 사용하여 모델이 안정적으로 학습을 시작할 수 있게 하고, 이후 점차 학습률을 증가시킵니다. 이렇게 하면 초기의 불안정한 학습을 방지할 수 있습니다.
- Learning Rate Decay:
- 학습이 진행됨에 따라 learning rate를 점차 줄여가는 방법입니다. 학습 초기에 큰 학습률을 사용하여 빠르게 최적점을 향해 가다가, 학습 후반부에는 작은 학습률을 사용하여 미세한 조정을 할 수 있게 합니다. 이렇게 하면 학습 후반부에 모델이 보다 안정적으로 수렴할 수 있습니다.
왜 사용하는가?
- 초기 불안정성 방지: 학습 초기에는 가중치가 무작위로 초기화되어 있기 때문에 큰 학습률을 사용하면 모델이 불안정하게 학습될 수 있습니다. 작은 학습률로 시작하여 이러한 문제를 방지합니다.
- 빠른 수렴: 학습 초기에는 큰 학습률을 사용하여 빠르게 최적화 방향으로 이동합니다.
- 정밀한 최적화: 학습 후반부에는 작은 학습률을 사용하여 보다 정밀하게 최적화할 수 있습니다.
- Fine-tunining
º Hyper-Parameter
- Optimizer : SGD
- Batch Size : 512
- 해상도 변경 : 1. ViT-L/16 : 512 2. ViT-H/14: 518
[구현] ( from scratch )
import torch
import torch.nn as nn
import torch.nn.functional as F
class PatchEmbedding(nn.Module):
def __init__(self, img_size, patch_size, in_channels, embed_dim):
super(PatchEmbedding, self).__init__()
self.patch_size = patch_size
self.img_size = img_size
self.num_patches = (img_size // patch_size) ** 2
self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
def forward(self, x):
x = self.proj(x) # (B, embed_dim, H', W')
x = x.flatten(2) # (B, embed_dim, num_patches)
x = x.transpose(1, 2) # (B, num_patches, embed_dim)
return x
class TransformerBlock(nn.Module):
def __init__(self, embed_dim, num_heads, mlp_ratio=4.):
super(TransformerBlock, self).__init__()
self.attn = nn.MultiheadAttention(embed_dim, num_heads)
self.mlp = nn.Sequential(
nn.Linear(embed_dim, int(embed_dim * mlp_ratio)),
nn.GELU(),
nn.Linear(int(embed_dim * mlp_ratio), embed_dim)
)
self.norm1 = nn.LayerNorm(embed_dim)
self.norm2 = nn.LayerNorm(embed_dim)
def forward(self, x):
attn_out, _ = self.attn(x, x, x)
x = self.norm1(x + attn_out)
mlp_out = self.mlp(x)
x = self.norm2(x + mlp_out)
return x
class VisionTransformer(nn.Module):
def __init__(self, img_size, patch_size, in_channels, embed_dim, num_heads, num_layers, num_classes):
super(VisionTransformer, self).__init__()
self.patch_embedding = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
self.transformer_blocks = nn.ModuleList([
TransformerBlock(embed_dim, num_heads) for _ in range(num_layers)
])
self.head = nn.Linear(embed_dim, num_classes)
def forward(self, x):
x = self.patch_embedding(x)
for block in self.transformer_blocks:
x = block(x)
x = x.mean(dim=1) # Global average pooling
x = self.head(x)
return x
# 모델 초기화 예시
img_size = 224
patch_size = 16
in_channels = 3
embed_dim = 768
num_heads = 12
num_layers = 12
num_classes = 1000
model = VisionTransformer(img_size, patch_size, in_channels, embed_dim, num_heads, num_layers, num_classes)
ViT 개선방안
1. VISION TRANSFORMERS NEED REGISTERS
DeiT - III : supervised learning
openclip : text - supervised learning
DINOv2 : self-supervised Learning
Arifact 가 vision transformer 구조의 고질적 문제이고, 이에대해서 분석하고 해결하고자 함.
그렇기 때문에 서로 다른 방식으로 학습된 vision transformer 들에서도 arifact 가 보편적으로 나타난다는점 과
제안 방법론을 통해 실제로 이러한 문제점이 해결됨을 보이고자 위의 세가지 모델을 논문에서 사용.
'ai' 카테고리의 다른 글
[Multi-Modal] (3) | 2024.11.05 |
---|---|
Transformer 구현 (0) | 2024.08.11 |
[논문]Chain-of-Thought Prompt Distillation for Multimodal Named Entity Recognition and Multimodal Relation Extraction (0) | 2024.05.04 |