Vision Transformer Code
Patch embedding

import torch
import torch.nn as nn
import torch.nn.functional as F

class patch_embedding(nn.Module):
    def __init__(self, patch_size, img_size, embed_size):
        super().__init__()
        # Split the image into non-overlapping patches and project each one
        # to embed_size with a single convolution (kernel = stride = patch_size).
        self.patch_embedding = nn.Conv2d(3, embed_size,
                                         kernel_size=patch_size,
                                         stride=patch_size)
        # Learnable cls token, prepended to the patch sequence.
        self.cls_token = nn.Parameter(torch.rand(1, 1, embed_size))
        # One extra position for the cls token, on top of the
        # (img_size // patch_size) ** 2 patch positions.
        self.position = nn.Parameter(torch.rand((img_size // patch_size) ** 2 + 1, embed_size))

    def forward(self, x):
        x = self.patch_embedding(x)    # (B, embed_size, H/P, W/P)
        x = x.flatten(2)               # (B, embed_size, N)
        x = x.transpose(2, 1)          # (B, N, embed_size)
        ct = self.cls_token.repeat(x.shape[0], 1, 1)
        x = torch.cat([ct, x], dim=1)  # (B, N + 1, embed_size)
        x = x + self.position
        return x
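A quick shape check (a minimal sketch; the batch size of 2 is arbitrary): a 224×224 image cut into 16×16 patches gives (224 / 16)² = 196 patch tokens, plus one cls token, for 197 in total.

pe = patch_embedding(patch_size=16, img_size=224, embed_size=768)
tokens = pe(torch.randn(2, 3, 224, 224))
print(tokens.shape)  # torch.Size([2, 197, 768])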
Multi-head Attention

class multi_head_attention(nn.Module):
    def __init__(self, embed_size, num_head, dropout_rate=0.1):
        super().__init__()
        self.q = nn.Linear(embed_size, embed_size)
        self.k = nn.Linear(embed_size, embed_size)
        self.v = nn.Linear(embed_size, embed_size)
        self.fc = nn.Linear(embed_size, embed_size)
        self.dropout = nn.Dropout(dropout_rate)
        self.num_head = num_head
        self.embed_size = embed_size

    def qkv_reshape(self, value, num_head):
        # (B, N, emb) -> (B, num_head, N, emb // num_head).
        # The view-then-transpose is required; a direct
        # view(b, num_head, n, dim) would scramble tokens across heads.
        b, n, emb = value.size()
        dim = emb // num_head
        return value.view(b, n, num_head, dim).transpose(1, 2)

    def forward(self, x):
        q = self.qkv_reshape(self.q(x), self.num_head)
        k = self.qkv_reshape(self.k(x), self.num_head)
        v = self.qkv_reshape(self.v(x), self.num_head)
        # Scaled dot-product attention, scaled by the per-head dimension.
        head_dim = self.embed_size // self.num_head
        qk = torch.matmul(q, k.transpose(3, 2))
        att = F.softmax(qk / (head_dim ** 0.5), dim=-1)
        att = torch.matmul(att, v)
        # Merge the heads back: (B, H, N, D) -> (B, N, H * D)
        b, h, n, d = att.size()
        x = att.transpose(1, 2).contiguous().view(b, n, h * d)
        x = self.fc(x)
        x = self.dropout(x)
        return x
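The attention block keeps the token shape, which is what lets it sit inside a residual connection later. A sanity check, reusing the 197-token sequence from the patch embedding above:

mha = multi_head_attention(embed_size=768, num_head=8)
out = mha(torch.randn(2, 197, 768))
print(out.shape)  # torch.Size([2, 197, 768])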
MLP

class MLP(nn.Module):
    def __init__(self, embed_size, expansion, dropout_rate):
        super().__init__()
        # Position-wise feed-forward network: expand by `expansion`,
        # apply GELU, then project back to embed_size.
        self.fc1 = nn.Linear(embed_size, embed_size * expansion)
        self.fc2 = nn.Linear(embed_size * expansion, embed_size)
        self.gelu = nn.GELU()
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        x = self.fc1(x)
        x = self.gelu(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x
Encoder Block

class skip_connection(nn.Module):
    # Residual wrapper: returns layer(x) + x.
    def __init__(self, layer):
        super().__init__()
        self.layer = layer

    def forward(self, x):
        return self.layer(x) + x

class EncoderBlock(nn.Module):
    def __init__(self,
                 embed_size,
                 num_head,
                 expansion,
                 dropout_rate):
        super().__init__()
        # Pre-norm Transformer block: LayerNorm -> attention -> residual,
        # then LayerNorm -> MLP -> residual. The dropout_rate argument is
        # forwarded to both sub-layers.
        self.skip_connection1 = skip_connection(
            nn.Sequential(
                nn.LayerNorm(embed_size),
                multi_head_attention(embed_size, num_head, dropout_rate=dropout_rate)
            )
        )
        self.skip_connection2 = skip_connection(
            nn.Sequential(
                nn.LayerNorm(embed_size),
                MLP(embed_size, expansion, dropout_rate=dropout_rate)
            )
        )

    def forward(self, x):
        x = self.skip_connection1(x)
        x = self.skip_connection2(x)
        return x
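Because each block maps (B, N, embed_size) to the same shape, any number of them can be stacked. A quick check, assuming the classes above are defined:

block = EncoderBlock(embed_size=768, num_head=8, expansion=4, dropout_rate=0.1)
out = block(torch.randn(2, 197, 768))
print(out.shape)  # torch.Size([2, 197, 768])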
Classifier Head

class Classifier_Head(nn.Module):
    def __init__(self, embed_size, num_classes):
        super().__init__()
        # Average-pool over all tokens, then LayerNorm + Linear to logits.
        self.avgpool1d = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Sequential(
            nn.LayerNorm(embed_size),
            nn.Linear(embed_size, num_classes)
        )

    def forward(self, x):
        x = x.transpose(2, 1)             # (B, embed_size, N)
        x = self.avgpool1d(x).squeeze(2)  # (B, embed_size)
        x = self.fc(x)
        return x
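Note a design choice here: the original ViT paper classifies from the cls token alone (x[:, 0]), whereas this head mean-pools over every token; both variants appear in practice and yield the same output shape.

head = Classifier_Head(embed_size=768, num_classes=10)
print(head(torch.randn(2, 197, 768)).shape)  # torch.Size([2, 10])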
ViT

class VIT(nn.Module):
    def __init__(self,
                 patch_size=16,
                 img_size=224,
                 embed_size=768,
                 num_head=8,
                 expansion=4,
                 dropout_rate=0.1,
                 encoder_depth=12,
                 num_classes=10):
        super().__init__()
        self.PatchEmbedding = patch_embedding(patch_size, img_size, embed_size)
        self.EncoderBlocks = self.make_layers(encoder_depth, embed_size, num_head, expansion, dropout_rate)
        self.ClassifierHead = Classifier_Head(embed_size, num_classes)

    def make_layers(self, encoder_depth, *args):
        # Stack encoder_depth identical encoder blocks.
        layers = [EncoderBlock(*args) for _ in range(encoder_depth)]
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.PatchEmbedding(x)  # (B, 3, H, W) -> (B, N + 1, embed_size)
        x = self.EncoderBlocks(x)   # (B, N + 1, embed_size)
        x = self.ClassifierHead(x)  # (B, num_classes)
        return x
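Putting it all together, an end-to-end smoke test (a minimal sketch; the batch size and random input are arbitrary):

model = VIT()
logits = model(torch.randn(2, 3, 224, 224))
print(logits.shape)  # torch.Size([2, 10])
print(sum(p.numel() for p in model.parameters()))  # total parameter count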
Written by 임중섭
Edited by 김주영