Swin Transformer Code#

SwinTransformer#

import math
import copy

import torch
import torch.nn as nn


class SwinTransformer(nn.Module):
    def __init__(self, class_num=100, C=96, num_heads=[3, 6, 12, 24], window_size=7, swin_num_list=[1, 1, 3, 1],
                 norm=True, img_size=224, dropout=0.1, ffn_dim=384):
        super(SwinTransformer, self).__init__()
        self.preprocessing = PreProcessing(hid_dim=C, norm=norm, img_size=img_size)

        features_list = [C, C * 2, C * 4, C * 8]

        stages = nn.ModuleList([])
        stage_layer = SwinTransformerLayer(C=features_list[0], num_heads=num_heads[0], window_size=window_size,
                                           ffn_dim=ffn_dim, act_layer=nn.GELU, dropout=dropout)
        stages.append(SwinTransformerBlock(stage_layer, swin_num_list[0]))
        for i in range(1, 4):
            stages.append(PatchMerging(features_list[i - 1]))
            stage_layer = SwinTransformerLayer(C=features_list[i], num_heads=num_heads[i], window_size=window_size,
                                               ffn_dim=ffn_dim, act_layer=nn.GELU, dropout=dropout)
            stages.append(SwinTransformerBlock(stage_layer, swin_num_list[i]))

        self.stages = stages
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.feature = features_list[-1]
        self.head = nn.Linear(features_list[-1], class_num)

    def forward(self, x):  # x: BS, C, H, W (channels-first image)
        BS, C, H, W = x.size()
        x = self.preprocessing(x)  # BS, L, C
        for stage in self.stages:
            x = stage(x)

        x = x.view(BS, -1, self.feature)

        x = self.avgpool(x.transpose(1, 2))  # BS, C, 1
        x = torch.flatten(x, 1)
        x = self.head(x)
        return x
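
A minimal smoke test, assuming all of the snippets in this post (plus the timm helper added in the SwinAttention section below) are collected in one file and using the default configuration:

# sanity check with the default 224x224 setting
model = SwinTransformer(class_num=100, C=96, img_size=224)
dummy = torch.randn(2, 3, 224, 224)   # BS, C, H, W
logits = model(dummy)
print(logits.shape)                   # torch.Size([2, 100])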

PreProcessing#

class PreProcessing(nn.Module):  # patch partition + linear embedding (4x4 strided conv)
    def __init__(self, hid_dim=96, norm=True, img_size=224):
        super().__init__()
        self.embed = nn.Conv2d(3, hid_dim, kernel_size=4, stride=4)
        self.norm_layer = None
        self.norm = norm
        if self.norm:
            self.norm_layer = nn.LayerNorm(hid_dim)

        self.num_patches = img_size // 4  # patches per side (patch size = 4)

        self.hid_dim = hid_dim

    def forward(self, x):  # x: BS, C, H, W (channels-first image)
        BS, C, H, W = x.size()

        x = self.embed(x).flatten(2).transpose(1, 2)  # BS, C, L -> BS, L, C

        if self.norm:
            x = self.norm_layer(x)

        return x  # BS, L, C
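
As a quick shape check with the defaults used above, the 4x4 convolution turns a 224x224 image into 56x56 = 3136 patch tokens of dimension 96:

pre = PreProcessing(hid_dim=96, img_size=224)
img = torch.randn(2, 3, 224, 224)     # BS, C, H, W
tokens = pre(img)
print(tokens.shape)                   # torch.Size([2, 3136, 96]) -> BS, L, C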

SwinTransformerBlock#

def _clone_layer(layer, num_layers):
    return nn.ModuleList([copy.deepcopy(layer) for _ in range(num_layers)])


class SwinTransformerBlock(nn.Module):
    def __init__(self, layer, num_layers):
        super().__init__()
        self.layers = _clone_layer(layer, num_layers)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)

        return x

SwinTransformerLayer#

def window_partition(x, window_size):
    # B, H, W, C : x.size -> B*Window_num, window_size, window_size, C
    B, H, W, C = x.size()
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows
    

def window_reverse(x, window_size, H, W):
    # B*Window_num, window_size, window_size, C - > B, H, W, C
    WN = (H // window_size) * (W // window_size)  # number of windows per image
    B = x.size()[0] // WN
    x = x.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x
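
The two helpers are exact inverses of each other; a quick round trip on a 56x56 feature map with 7x7 windows:

x = torch.randn(2, 56, 56, 96)        # B, H, W, C
windows = window_partition(x, 7)      # (2*64, 7, 7, 96)
restored = window_reverse(windows, 7, 56, 56)
print(torch.equal(x, restored))       # True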


class SwinTransformerLayer(nn.Module):
    def __init__(self, C, num_heads, window_size, ffn_dim, act_layer=nn.GELU, dropout=0.1):
        super().__init__()
        self.mlp1 = Mlp(C, ffn_dim, act_layer=nn.GELU, drop=dropout)
        self.mlp2 = Mlp(C, ffn_dim, act_layer=nn.GELU, drop=dropout)

        self.norm1 = nn.LayerNorm(C)
        self.norm2 = nn.LayerNorm(C)
        self.norm3 = nn.LayerNorm(C)
        self.norm4 = nn.LayerNorm(C)

        self.shift_size = window_size // 2
        self.window_size = window_size
        self.W_MSA = SwinAttention(num_heads=num_heads, C=C, dropout=dropout)
        self.SW_MSA = SwinAttention(num_heads=num_heads, C=C, dropout=dropout)

    def forward(self, x):  # BS, L, C
        BS, L, C = x.shape
        S = int(math.sqrt(L))

        shortcut = x

        x = self.norm1(x)  # BS, L, C

        x_windows = self.window_to_attention(x, S, C)

        attn_x = self.W_MSA(x_windows)

        x = self.attention_to_og(attn_x, S, C)

        x = x + shortcut

        shortcut = x

        x = self.norm2(x)
        x = self.mlp1(x)

        x = x + shortcut

        shortcut = x

        x = self.norm3(x)

        x_windows = self.window_to_attention(x, S, C, shift=True)

        # Note: the attention mask that the official implementation applies to
        # shifted windows is omitted in this simplified version.
        attn_x = self.SW_MSA(x_windows)

        x = self.attention_to_og(attn_x, S, C, shift=True)

        x = x + shortcut

        shortcut = x

        x = self.norm4(x)
        x = self.mlp2(x)

        return x + shortcut

    def window_to_attention(self, x, S, C, shift=False):
        x = x.view(-1, S, S, C)  # BS, L, C -> BS, S, S, C
        if shift:  # cyclic shift for SW-MSA
            x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        x_windows = window_partition(x, self.window_size)  # BS*num_windows, ws, ws, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
        return x_windows

    def attention_to_og(self, attn_x, S, C, shift=False):
        attn_x = attn_x.view(-1, self.window_size, self.window_size, C)
        x = window_reverse(attn_x, self.window_size, S, S)
        if shift:
            x = torch.roll(x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        x = x.view(-1, S * S, C)
        return x
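
Each layer runs one W-MSA block and one SW-MSA block and keeps the token shape unchanged; for example, with the first-stage settings:

layer = SwinTransformerLayer(C=96, num_heads=3, window_size=7, ffn_dim=384)
tokens = torch.randn(2, 56 * 56, 96)  # BS, L, C
print(layer(tokens).shape)            # torch.Size([2, 3136, 96])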

SwinAttention#

# from timm (a method of WindowAttention there); called from SwinAttention.forward below
def _get_rel_pos_bias(self) -> torch.Tensor:
    relative_position_bias = self.relative_position_bias_table[
        self.relative_position_index.view(-1)].view(self.window_area, self.window_area, -1)  # Wh*Ww, Wh*Ww, nH
    relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
    return relative_position_bias.unsqueeze(0)  # 1, nH, Wh*Ww, Wh*Ww
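
The attention module also calls get_relative_position_index, which this post does not show; the helper below is adapted from the same timm file linked in the code:

def get_relative_position_index(win_h, win_w):
    # pair-wise relative position index for each token pair inside the window (adapted from timm)
    coords = torch.stack(torch.meshgrid([torch.arange(win_h), torch.arange(win_w)]))  # 2, Wh, Ww
    coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
    relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
    relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
    relative_coords[:, :, 0] += win_h - 1  # shift to start from 0
    relative_coords[:, :, 1] += win_w - 1
    relative_coords[:, :, 0] *= 2 * win_w - 1
    return relative_coords.sum(-1)  # Wh*Ww, Wh*Ww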


class SwinAttention(nn.Module):
    def __init__(self, num_heads, C, dropout):
        super().__init__()

        self.num_heads = num_heads
        self.scale = (C // num_heads) ** -0.5  # scale by the per-head dim, as in the reference implementation

        self.qkv = nn.Linear(C, C * 3, bias=True)

        self.softmax = nn.Softmax(dim=-1)
        self.attn_drop = nn.Dropout(dropout)

        self.proj = nn.Linear(C, C)
        self.proj_drop = nn.Dropout(dropout)

        # from timm
        # https://github.com/rwightman/pytorch-image-models/blob/main/timm/models/swin_transformer.py#L134
        self.win_h = 7
        self.win_w = 7
        self.window_area = self.win_h * self.win_w
        # define a parameter table of relative position bias, shape: (2*Wh-1)*(2*Ww-1), nH
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * self.win_h - 1) * (2 * self.win_w - 1), num_heads))
        # get pair-wise relative position index for each token inside the window
        self.register_buffer("relative_position_index", get_relative_position_index(self.win_h, self.win_w))

    def forward(self, x):  # x: B*num_windows, window_size*window_size, C
        B, L, C = x.shape

        qkv = self.qkv(x).reshape(B, L, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)  # 3, B, Head, L, C_v

        q, k, v = qkv[0], qkv[1], qkv[2]

        q = q * self.scale

        attn = q @ k.transpose(-1, -2)  # dot product: B, Head, L, L

        attn_score = self.softmax(attn + _get_rel_pos_bias(self))  # add relative position bias, then softmax
        attn_score = self.attn_drop(attn_score)

        out = (attn_score @ v).transpose(1, 2).flatten(-2)  # B, L, C

        out = self.proj(out)
        out = self.proj_drop(out)

        return out
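
The module operates on windows that have already been flattened to window_size*window_size tokens, e.g. the 128 windows (2 images x 64 windows) produced in the first stage above:

attn = SwinAttention(num_heads=3, C=96, dropout=0.1)
windows = torch.randn(128, 7 * 7, 96)  # B*num_windows, window_size*window_size, C
print(attn(windows).shape)             # torch.Size([128, 49, 96])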

Mlp#

class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, act_layer=nn.GELU, drop=0.1):
        super().__init__()
        hidden_features = hidden_features or in_features  # fall back to in_features if not given
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, in_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

PatchMerging#

class PatchMerging(nn.Module):
    def __init__(self, C):
        super().__init__()
        self.proj = nn.Linear(C * 4, C * 2)

    def forward(self, x):  # x: BS, L, C
        BS, L, C = x.size()
        H = int(math.sqrt(L))
        # simplified merging: every four consecutive tokens are concatenated along the
        # channel dim by a plain reshape (the official implementation gathers the four
        # 2x2 spatially-neighbouring patches instead)
        x = x.view(BS, H // 2, H // 2, C * 4)
        x = self.proj(x)  # 4C -> 2C
        x = x.view(BS, -1, C * 2)
        return x
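
PatchMerging quarters the number of tokens and doubles the channel dimension, matching the stage transitions used in SwinTransformer above:

merge = PatchMerging(C=96)
tokens = torch.randn(2, 56 * 56, 96)  # BS, L, C
print(merge(tokens).shape)            # torch.Size([2, 784, 192])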

Author: 임중섭
Editor: 김주영