DeiT

DeiT
import torch
import torch.nn as nn
import torch.nn.functional as F

class PatchEmbedding(nn.Module):
    def __init__(self, img_size, in_channels=3, patch_size=16, embedding_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size                     # 16x16 patches
        self.n_patches = (img_size // patch_size) ** 2   # number of patches per image
        # Non-overlapping convolution projects each patch to embedding_dim
        self.proj = nn.Conv2d(in_channels,
                              embedding_dim,
                              kernel_size=self.patch_size,
                              stride=self.patch_size)
        # Learnable class token, distillation token and position embedding
        self.cls_token = nn.Parameter(torch.rand(1, 1, embedding_dim))
        self.dist_token = nn.Parameter(torch.rand(1, 1, embedding_dim))
        self.position_embedding = nn.Parameter(torch.rand(1, 2 + self.n_patches, embedding_dim))

    def forward(self, x):
        n, c, h, w = x.shape
        x = self.proj(x)        # (batch, embedding_dim, 14, 14)
        x = x.flatten(2)        # (batch, embedding_dim, n_patches)
        x = x.transpose(1, 2)   # (batch, n_patches, embedding_dim)
        # Expand the class and distillation tokens to the full batch
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        dist_token = self.dist_token.expand(x.shape[0], -1, -1)
        # Prepend class token and distillation token
        x = torch.cat([cls_token, dist_token, x], dim=1)  # (batch, n_patches + 2, embedding_dim)
        # Add position embedding
        position_embedding = self.position_embedding.expand(x.shape[0], -1, -1)
        x = x + position_embedding
        return x
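As a quick sanity check (my own addition, not part of the original code), a 224x224 RGB batch should produce 14x14 + 2 = 198 tokens of dimension 768:

# Hypothetical shape check for PatchEmbedding, assuming a 224x224 input.
patch_embed = PatchEmbedding(img_size=224)
dummy = torch.randn(2, 3, 224, 224)
print(patch_embed(dummy).shape)   # torch.Size([2, 198, 768])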
class MultiHeadAttention(nn.Module):
    def __init__(self, dim, n_heads=12, qkv_bias=True, attn_p=0., proj_p=0.):
        super().__init__()
        self.dim = dim
        self.n_heads = n_heads
        self.head_dim = self.dim // n_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_p)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_p)

    def forward(self, x):
        B, N, C = x.shape                                            # (b, 198, 768)
        qkv = self.qkv(x)                                            # (b, 198, 768*3)
        qkv = qkv.reshape(B, N, 3, self.n_heads, C // self.n_heads)  # (b, 198, 768*3) -> (b, 198, 3, 12, 64)
        qkv = qkv.permute(2, 0, 3, 1, 4)                             # (b, 198, 3, 12, 64) -> (3, b, 12, 198, 64)
        q, k, v = qkv[0], qkv[1], qkv[2]                             # each (b, 12, 198, 64)
        # Scaled dot-product: q @ k^T
        attention = (q @ k.transpose(-2, -1)) * self.scale           # (b, 12, 198, 198)
        attention = attention.softmax(dim=-1)
        attention = self.attn_drop(attention)
        # Weighted sum of values, then merge the heads back
        attention = (attention @ v).transpose(1, 2).reshape(B, N, C)  # (b, 198, 768)
        attention = self.proj(attention)
        attention = self.proj_drop(attention)
        return attention
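A minimal sketch (again my own addition) confirming that attention keeps the (batch, tokens, dim) shape, assuming the 198-token, 768-dim sequence from the patch embedding:

# Hypothetical shape check for MultiHeadAttention.
attn = MultiHeadAttention(dim=768, n_heads=12)
tokens = torch.randn(2, 198, 768)
print(attn(tokens).shape)   # torch.Size([2, 198, 768])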
class MLP(nn.Module):
    def __init__(self, dim, expansion=4, p=0.):
        super().__init__()
        self.fc1 = nn.Linear(dim, dim * expansion)
        self.act = nn.GELU()
        self.dropout = nn.Dropout(p)
        self.fc2 = nn.Linear(dim * expansion, dim)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return x
class EncoderBlock(nn.Module):
    def __init__(self, dim, n_heads, expansion=4, qkv_bias=True, p=0., attn_p=0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attention = MultiHeadAttention(dim,
                                            n_heads,
                                            qkv_bias,
                                            attn_p=attn_p,
                                            proj_p=p)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, expansion=expansion, p=p)

    def forward(self, x):
        # Pre-norm transformer block with residual connections
        x = x + self.attention(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x
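Since both sub-layers use pre-normalization and residual connections, an encoder block preserves the token shape; a quick check (my own addition):

# Hypothetical shape check for one encoder block.
block = EncoderBlock(dim=768, n_heads=12)
print(block(torch.randn(2, 198, 768)).shape)   # torch.Size([2, 198, 768])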
class MLPHead(nn.Module):
    def __init__(self, embedding_dim, num_classes):
        super().__init__()
        self.norm1 = nn.LayerNorm(embedding_dim)
        self.fc1 = nn.Linear(embedding_dim, num_classes)

    def forward(self, x):
        x = self.norm1(x)
        x = self.fc1(x)
        return x
class DeiT(nn.Module):
    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 in_channels=3,
                 num_classes=1000,
                 embedding_dim=768,
                 depth=12,
                 n_heads=8,
                 expansion=4,
                 qkv_bias=True,
                 p=0.,
                 attn_p=0.):
        super().__init__()
        self.patch_embedding = PatchEmbedding(img_size=img_size,
                                              in_channels=in_channels,
                                              patch_size=patch_size,
                                              embedding_dim=embedding_dim)
        self.enc_blocks = nn.ModuleList([EncoderBlock(dim=embedding_dim,
                                                      n_heads=n_heads,
                                                      expansion=expansion,
                                                      qkv_bias=qkv_bias,
                                                      p=p,
                                                      attn_p=attn_p)
                                         for _ in range(depth)])
        # Separate heads for the class token and the distillation token
        self.mlp_cls = MLPHead(embedding_dim=embedding_dim,
                               num_classes=num_classes)
        self.mlp_dist = MLPHead(embedding_dim=embedding_dim,
                                num_classes=num_classes)

    def forward(self, x):
        x = self.patch_embedding(x)
        for encoder in self.enc_blocks:
            x = encoder(x)
        cls_token_final = x[:, 0]
        dist_token_final = x[:, 1]
        x_cls = self.mlp_cls(cls_token_final)
        x_dist = self.mlp_dist(dist_token_final)
        if self.training:
            # Training: return both heads so the distillation loss can use them
            return x_cls, x_dist
        else:
            # Inference: average the predictions of the two heads
            return (x_cls + x_dist) / 2
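A brief usage sketch (my own addition, with a small depth just to keep it light): in train mode the model returns the two heads' logits separately for the distillation loss, while in eval mode it returns their average.

# Hypothetical usage: train mode returns two heads, eval mode averages them.
model = DeiT(num_classes=1000, depth=2)   # depth=2 only for this demo
images = torch.randn(2, 3, 224, 224)

model.train()
x_cls, x_dist = model(images)
print(x_cls.shape, x_dist.shape)   # torch.Size([2, 1000]) torch.Size([2, 1000])

model.eval()
with torch.no_grad():
    print(model(images).shape)     # torch.Size([2, 1000])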
Hard distillation global loss
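Before the code, it helps to write out what it computes: DeiT's hard-label distillation loss, here weighted by α between the ground-truth cross-entropy on the class token and the cross-entropy on the distillation token against the teacher's hard prediction (the paper uses a fixed 1/2, 1/2 split; the temperature τ only matters for the soft-distillation variant and is unused below).

$$
\mathcal{L}_{\text{global}}^{\text{hardDistill}}
= (1-\alpha)\,\mathcal{L}_{\text{CE}}\!\left(Z_{\text{cls}},\, y\right)
+ \alpha\,\mathcal{L}_{\text{CE}}\!\left(Z_{\text{dist}},\, y_{\text{teacher}}\right),
\qquad
y_{\text{teacher}} = \arg\max_{c} Z_{\text{teacher}}(c)
$$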
import torch
import torch.nn as nn
import torch.nn.functional as F

class Hard_Distillation_Global_Loss(nn.Module):
    def __init__(self, teacher, alpha, tau):
        super(Hard_Distillation_Global_Loss, self).__init__()
        self.teacher = teacher
        self.alpha = alpha
        self.tau = tau  # temperature is only needed for soft distillation; unused here

    def forward(self, inputs, outputs_student, labels):
        cls_token, dist_token = outputs_student
        # The teacher is frozen: its hard predictions are the distillation targets
        with torch.no_grad():
            outputs_teacher = self.teacher(inputs)
        loss = (1 - self.alpha) * F.cross_entropy(cls_token, labels) \
             + self.alpha * F.cross_entropy(dist_token, outputs_teacher.argmax(dim=1))
        return loss
Train
from model.DeiT import DeiT
from util.loss import *
from tqdm import tqdm
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader
import torchvision.models as models

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

student = DeiT(img_size=224,
               patch_size=16,
               in_channels=3,
               num_classes=1000,
               embedding_dim=768,
               depth=12,
               n_heads=8,
               expansion=4,
               qkv_bias=True,
               p=0.,
               attn_p=0.)
student.to(device)

# Teacher: a convnet with pretrained ImageNet weights (weight choice is an assumption)
teacher = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
teacher.to(device)

# Freeze the teacher weights
for params in teacher.parameters():
    params.requires_grad = False

criterion = Hard_Distillation_Global_Loss(teacher=teacher, alpha=0.5, tau=1).to(device)
optimizer = optim.AdamW(student.parameters(), lr=5e-4)   # optimizer choice is an assumption
EPOCHS = 100                                             # assumed; not given in the original

# train_loader is assumed to be an ImageNet-style DataLoader (its definition is omitted here)

student.train()
teacher.eval()
for epoch in range(1, EPOCHS + 1):
    loop = tqdm(enumerate(train_loader), total=len(train_loader), leave=False)
    for i, (inputs, labels) in loop:
        inputs = inputs.to(device)
        labels = labels.to(device)

        outputs_student = student(inputs)
        loss = criterion(inputs, outputs_student, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
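After training, evaluation switches the student to eval mode so that DeiT.forward returns the averaged class/distillation logits. A minimal sketch (my own addition), assuming a val_loader defined analogously to train_loader:

# Hypothetical evaluation step: eval mode averages the two heads inside DeiT.forward
student.eval()
with torch.no_grad():
    for inputs, labels in val_loader:   # val_loader is assumed
        inputs = inputs.to(device)
        logits = student(inputs)        # (batch, num_classes)
        preds = logits.argmax(dim=1)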
Authored by 이명오
Edited by 김주영