DeiT#

DeiT#

import torch
import torch.nn as nn
import torch.nn.functional as F


class PatchEmbedding(nn.Module):
    def __init__(self, img_size, in_channels=3, patch_size=16, embbeding_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size # 16x16
        self.n_patches = (img_size // patch_size) **2 # number of patches in image

        self.proj = nn.Conv2d(in_channels,
                              embbeding_dim,
                              kernel_size=self.patch_size,
                              stride=self.patch_size)

        self.cls_token = nn.Parameter(torch.rand(1, 1, embbeding_dim))
        self.dist_token = nn.Parameter(torch.rand(1, 1, embbeding_dim))
        self.position_embedding = nn.Parameter(torch.rand(1, 2 + self.n_patches, embbeding_dim))

    def forward(self, x):
        n, c, h, w = x.shape

        x = self.proj(x) # (batch, embedding_dim, 14, 14)                
        x = x.flatten(2) # (batch, embedding_dim, n_patches)        
        x = x.transpose(1, 2) # (batch, n_patches, embedding_dim) 


        # Expand the class token to the full batch
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        dist_token = self.dist_token.expand(x.shape[0], -1, -1)


        # add class token, dist token        
        x = torch.cat([cls_token, dist_token, x], dim=1) # (batch, n_patches + 2, embedding_dim)  

        
        # add position embedding        
        position_embedding = self.position_embedding.expand(x.shape[0], -1, -1)
        x = x + position_embedding

        return x


class MultiHeadAttention(nn.Module):
    def __init__(self, dim, n_heads=12, qkv_bias=True, attn_p=0., proj_p=0.):
        super().__init__()
        self.dim = dim
        self.n_heads = n_heads
        self.head_dim = self.dim // n_heads
        self.scale = self.head_dim ** -0.5


        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_p)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_p)



    def forward(self, x):
        B, N, C = x.shape # (b, 198, 768)

        qkv = self.qkv(x) # (b, 198, 768*3)
        qkv = qkv.reshape(B, N, 3, self.n_heads, C // self.n_heads) # (b, 198, 768*3) -> (b, 198, 3, 12, 64)
        qkv = qkv.permute(2, 0, 3, 1, 4) # (b, 198, 3, 12, 64) -> (3, b, 12, 198, 96)

        q, k, v = qkv[0], qkv[1], qkv[2] # (b, 12, 198, 64)

        # q * k
        attention = (q @ k.transpose(-2, -1)) * self.scale # (8, 12, 198, 198) * scale
        attention = attention.softmax(dim=-1)
        attention = self.attn_drop(attention)

        # attention * v
        attention = (attention @ v).transpose(1, 2).reshape(B, N, C) # (b, 198, 768)
        attention = self.proj(attention)
        attention = self.proj_drop(attention)
        
        return attention


class MLP(nn.Module):
    def __init__(self, dim, expansion=4, p=0.):
        super().__init__()

        self.fc1 = nn.Linear(dim, dim*expansion)
        self.act = nn.GELU()
        self.dropout = nn.Dropout(p)
        self.fc2 = nn.Linear(dim*expansion, dim)
        

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.dropout(x)
        x = self.fc2(x)        

        return x


class EncoderBlock(nn.Module):
    def __init__(self, dim, n_heads, expansion=4, qkv_bias=True, p=0., attn_p=0.):
        super().__init__()
        
        self.norm1 = nn.LayerNorm(dim)
        self.attention = MultiHeadAttention(dim, 
                                            n_heads, 
                                            qkv_bias,
                                            attn_p=attn_p,
                                            proj_p=p)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, expansion=expansion, p=0.1)

    def forward(self, x):
        x = x + self.attention(self.norm1(x))
        x = x + self.mlp(self.norm2(x))

        return x


class MLPHead(nn.Module):
    def __init__(self, embedding_dim, num_classes):
        super().__init__()

        self.norm1 = nn.LayerNorm(embedding_dim)
        self.fc1 = nn.Linear(embedding_dim, num_classes)
        
    def forward(self, x):
        x = self.norm1(x)
        x = self.fc1(x)

        return x


class DeiT(nn.Module):
    def __init__(self, 
                img_size=224,
                patch_size=16, 
                in_channels=3, 
                num_classes=1000, 
                embbeding_dim=768, 
                depth=12, 
                n_heads=8, 
                expansion=4, 
                qkv_bias=True,
                p=0.,
                attn_p=0.,
                is_training=True):
        super().__init__()
        

        self.is_training = is_training


        self.patch_embedding = PatchEmbedding(img_size=img_size,
                                              in_channels=in_channels,
                                              patch_size=patch_size,
                                              embbeding_dim=embbeding_dim)


        self.enc_blocks = nn.ModuleList([EncoderBlock(dim=embbeding_dim,
                                                    n_heads=n_heads,
                                                    expansion=expansion,
                                                    qkv_bias=qkv_bias,
                                                    p=p,
                                                    attn_p=attn_p)
                                                    for _ in range(depth)])
        
        
        self.mlp_cls = MLPHead(embedding_dim=embbeding_dim,
                                num_classes=num_classes)
        

        self.mlp_dist = MLPHead(embedding_dim=embbeding_dim,
                                num_classes=num_classes)

    def forward(self, x):
        x = self.patch_embedding(x)

        for encoder in self.enc_blocks:
            x = encoder(x)
            
        cls_token_final = x[:, 0]
        dist_token_final = x[:, 1]

        x_cls = self.mlp_cls(cls_token_final)
        x_dist = self.mlp_dist(dist_token_final)

        if self.is_training:
            return x_cls, x_dist
        else: 
            # inference
            return (x_cls + x_dist) / 2

Hard distillation global loss#

import torch
import torch.nn as nn
import torch.nn.functional as F


class Hard_Disitillation_Global_Loss(nn.Module):
    def __init__(self, teacher, alpha, tau):
        super(Hard_Disitillation_Global_Loss, self).__init__()

        self.teacher = teacher
        self.alpha = alpha
        self.tau = tau
    
    def forward(self, inputs, outputs_student, labels):        
        cls_token, dist_token = outputs_student

        with torch.no_grad():
            outputs_teacher = self.teacher(inputs)
        
        loss = ((1-self.alpha)* F.CrossEntropyLoss(cls_token, labels)) + (self.alpha * F.CrossEntropyLoss(dist_token, outputs_teacher.argmax(dim=1)))
        return loss

Train#

from model.DeiT import DeiT
from util.loss import *
from tqdm import tqdm
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader
import torchvision.models as models

if torch.cuda.is_available():
    device = torch.device('cuda:0')
    print(device)

student = DeiT(img_size=224, 
            patch_size=16, 
            in_channels=3, 
            num_classes=1000, 
            embbeding_dim=768, 
            depth=12, 
            n_heads=8, 
            expansion=4, 
            qkv_bias=True, 
            p=0., 
            attn_p=0.,
            training=True)    
student.to(device)

teacher = models.resnet50()
teacher.to(device)

# teacher weight freeze
for params in teacher.parameters():    
    params.requires_grad = False

criterion = Hard_Disitillation_Global_Loss(teacher=teacher, alpha=0.5, tau=1).to(device)

student.train()
teacher.eval()

for epoch in range(1, EPOCHS+1):

    loop = tqdm(enumerate(train_loader), total=len(train_loader), leave=False)    
    for inputs, labels in loop:            
        inputs = inputs.to(device)
        labels = labels.to(device)

        outputs_student = student(inputs)
        loss = criterion(inputs, outputs_student, labels)

Author by 이명오
Edit by 김주영