Swin Transformer Code
class SwinTransformer(nn.Module):
    """Swin Transformer image classifier.

    The network embeds the image into patch tokens, runs four stages of
    ``SwinTransformerBlock`` (with a ``PatchMerging`` step before stages
    2-4), average-pools the final tokens and applies a linear head.

    Args:
        class_num: Number of output classes of the linear head.
        C: Embedding width of the first stage; later stages use 2C, 4C, 8C.
        num_heads: Attention heads per stage (indexed 0..3).
        window_size: Side length of the attention window.
        swin_num_list: Number of stacked layers per stage (indexed 0..3).
        norm: Forwarded to ``PreProcessing``.
        img_size: Forwarded to ``PreProcessing``.
        dropout: Dropout probability forwarded to every layer.
        ffn_dim: Hidden width of the feed-forward networks.
            NOTE(review): the same value is used for every stage even though
            the channel width doubles per stage - confirm this is intended.
    """

    def __init__(self, class_num=100, C=96, num_heads=(3, 6, 12, 24), window_size=7,
                 swin_num_list=(1, 1, 3, 1), norm=True, img_size=224, dropout=0.1, ffn_dim=384):
        super().__init__()
        self.preprocessing = PreProcessing(hid_dim=C, norm=norm, img_size=img_size)
        features_list = [C, C * 2, C * 4, C * 8]

        stages = nn.ModuleList()
        for i, features in enumerate(features_list):
            if i > 0:
                # Every stage after the first starts by merging patches, which
                # takes the previous stage's width to the current one.
                stages.append(PatchMerging(features_list[i - 1]))
            stage_layer = SwinTransformerLayer(C=features, num_heads=num_heads[i], window_size=window_size,
                                               ffn_dim=ffn_dim, act_layer=nn.GELU, dropout=dropout)
            stages.append(SwinTransformerBlock(stage_layer, swin_num_list[i]))
        self.stages = stages

        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.feature = features_list[-1]
        self.head = nn.Linear(features_list[-1], class_num)

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: 4-D image batch whose first dimension is the batch size; it is
                passed unchanged to ``PreProcessing``.

        Returns:
            Tensor of shape (batch, class_num) with the class logits.
        """
        batch_size = x.size(0)
        x = self.preprocessing(x)  # BS, L, C
        for stage in self.stages:
            x = stage(x)
        x = x.view(batch_size, -1, self.feature)
        # Pool over the token axis: (BS, L, C) -> (BS, C, L) -> (BS, C, 1).
        x = self.avgpool(x.transpose(1, 2))
        x = torch.flatten(x, 1)
        return self.head(x)
class PreProcessing(nn.Module):
    """Patch partition and linear embedding.

    A 4x4 convolution with stride 4 turns every non-overlapping 4x4 patch of
    a 3-channel image into one ``hid_dim``-wide token.

    Args:
        hid_dim: Width of each patch token.
        norm: When true, a ``LayerNorm`` is applied to the tokens.
        img_size: Image side length; only used to derive ``num_patches``.
    """

    def __init__(self, hid_dim=96, norm=True, img_size=224):
        # nn.Module must be initialised before sub-modules can be assigned.
        super().__init__()
        self.embed = nn.Conv2d(3, hid_dim, kernel_size=4, stride=4)
        self.norm = norm
        self.norm_layer = nn.LayerNorm(hid_dim) if norm else None
        # Number of patches along one image side (stride-4 embedding).
        self.num_patches = img_size // 4
        self.hid_dim = hid_dim

    def forward(self, x):
        """Embed an image batch into patch tokens.

        Args:
            x: Tensor of shape (BS, 3, H, W), as required by the embedding
                convolution.

        Returns:
            Tensor of shape (BS, L, hid_dim) with L = (H // 4) * (W // 4).
        """
        # BS, C, H', W' -> BS, C, L -> BS, L, C
        x = self.embed(x).flatten(2).transpose(1, 2)
        if self.norm_layer is not None:
            x = self.norm_layer(x)
        return x  # [BS, L, C]
def _clone_layer(layer, num_layers):
    """Build an ``nn.ModuleList`` of ``num_layers`` deep copies of ``layer``.

    Each entry is produced with ``copy.deepcopy``, so the copies do not share
    parameters with ``layer`` or with each other.
    """
    clones = nn.ModuleList()
    for _ in range(num_layers):
        clones.append(copy.deepcopy(layer))
    return clones
class SwinTransformerBlock(nn.Module):
    """Sequential stack of ``num_layers`` independent copies of ``layer``.

    Args:
        layer: Module to replicate.
        num_layers: Number of copies to stack.
    """

    def __init__(self, layer, num_layers):
        # nn.Module must be initialised before sub-modules can be assigned.
        super().__init__()
        self.layers = _clone_layer(layer, num_layers)

    def forward(self, x):
        """Feed ``x`` through every stacked layer in order and return the result."""
        for layer in self.layers:
            x = layer(x)
        return x
def window_partition(x, window_size):
    """Cut a feature map into non-overlapping square windows.

    Args:
        x: Tensor of shape (B, H, W, C).
        window_size: Side length of each window.

    Returns:
        Tensor of shape (B * num_windows, window_size, window_size, C), with
        windows ordered batch-first, then row-major over the window grid.
    """
    batch, height, width, channels = x.size()
    grid_rows = height // window_size
    grid_cols = width // window_size
    tiled = x.view(batch, grid_rows, window_size, grid_cols, window_size, channels)
    # Put the two window-grid axes first so each window becomes contiguous.
    tiled = tiled.permute(0, 1, 3, 2, 4, 5).contiguous()
    return tiled.view(-1, window_size, window_size, channels)
def window_reverse(x, window_size, H, W):
    """Reassemble windows into a full feature map.

    Args:
        x: Tensor of shape (B * num_windows, window_size, window_size, C),
            with windows ordered batch-first, then row-major over the grid.
        window_size: Side length of each window.
        H: Height of the reassembled feature map.
        W: Width of the reassembled feature map.

    Returns:
        Tensor of shape (B, H, W, C).
    """
    grid_rows = H // window_size
    grid_cols = W // window_size
    # Count windows from both dimensions so non-square maps (H != W) recover
    # the right batch size.
    B = x.size(0) // (grid_rows * grid_cols)
    x = x.view(B, grid_rows, grid_cols, window_size, window_size, -1)
    # Interleave grid and in-window axes back into (rows, cols) order.
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
class SwinTransformerLayer(nn.Module):
    """One Swin layer: window attention + MLP, then shifted-window attention + MLP.

    Each of the four sub-blocks is pre-normalised with its own ``LayerNorm``
    and wrapped in a residual connection.

    Args:
        C: Channel width of the tokens.
        num_heads: Number of attention heads.
        window_size: Side length of the attention window; the shifted pass
            rolls the feature map by ``window_size // 2``.
        ffn_dim: Hidden width of the two MLPs.
        act_layer: Activation class used by the MLPs.
        dropout: Dropout probability forwarded to attention and MLPs.
    """

    def __init__(self, C, num_heads, window_size, ffn_dim, act_layer=nn.GELU, dropout=0.1):
        # nn.Module must be initialised before sub-modules can be assigned.
        super().__init__()
        self.mlp1 = Mlp(C, ffn_dim, act_layer=act_layer, drop=dropout)
        self.mlp2 = Mlp(C, ffn_dim, act_layer=act_layer, drop=dropout)
        self.norm1 = nn.LayerNorm(C)
        self.norm2 = nn.LayerNorm(C)
        self.norm3 = nn.LayerNorm(C)
        self.norm4 = nn.LayerNorm(C)
        self.shift_size = window_size // 2
        self.window_size = window_size
        # NOTE(review): the attention modules are built without a window size,
        # so they rely on their own window setting - confirm it matches
        # ``window_size`` when a non-default value is used.
        self.W_MSA = SwinAttention(num_heads=num_heads, C=C, dropout=dropout)
        self.SW_MSA = SwinAttention(num_heads=num_heads, C=C, dropout=dropout)

    def forward(self, x):  # BS, L, C
        """Apply the layer to a token sequence.

        Args:
            x: Tensor of shape (BS, L, C); L is treated as a square grid of
                side ``int(sqrt(L))``.

        Returns:
            Tensor of shape (BS, L, C).
        """
        BS, L, C = x.shape
        S = int(math.sqrt(L))

        # 1) Window attention with residual.
        shortcut = x
        x = self.norm1(x)  # BS, L, C
        x_windows = self.window_to_attention(x, S, C)
        attn_x = self.W_MSA(x_windows)
        x = self.attention_to_og(attn_x, S, C) + shortcut

        # 2) MLP with residual around the *post-attention* tokens.
        shortcut = x
        x = self.mlp1(self.norm2(x)) + shortcut

        # 3) Shifted-window attention with residual.
        # NOTE(review): the shift is a cyclic roll and no attention mask is
        # passed, so tokens wrapped across the border attend to each other -
        # confirm this is intended.
        shortcut = x
        x = self.norm3(x)
        x_windows = self.window_to_attention(x, S, C, shift=True)
        x_attn = self.SW_MSA(x_windows)
        # Reassemble the attention output (not the normalised input).
        x = self.attention_to_og(x_attn, S, C, shift=True) + shortcut

        # 4) MLP with residual.
        shortcut = x
        x = self.mlp2(self.norm4(x))
        return x + shortcut

    def window_to_attention(self, x, S, C, shift=False):
        """Reshape tokens (BS, L, C) into per-window sequences.

        Returns a tensor of shape (BS * num_windows, window_size**2, C). When
        ``shift`` is true the S x S grid is first rolled by ``-shift_size``
        along both spatial axes.
        """
        # B, L, C -> B, S, S, C
        x = x.view(-1, S, S, C)
        if shift:  # cyclic shift for the shifted-window pass
            x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        x_windows = window_partition(x, self.window_size)
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
        return x_windows

    def attention_to_og(self, attn_x, S, C, shift=False):
        """Inverse of ``window_to_attention``: windows back to (BS, S*S, C).

        When ``shift`` is true the cyclic roll is undone after the windows
        have been reassembled.
        """
        attn_x = attn_x.view(-1, self.window_size, self.window_size, C)
        x = window_reverse(attn_x, self.window_size, S, S)
        if shift:
            x = torch.roll(x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        x = x.view(-1, S * S, C)
        return x

    def _get_rel_pos_bias(self) -> torch.Tensor:
        # NOTE(review): this method reads ``relative_position_bias_table``,
        # ``relative_position_index`` and ``window_area``, none of which are
        # assigned in this class's __init__, and nothing in this class calls
        # it. It looks like it belongs on the attention module - confirm
        # before relying on it. Logic left unchanged.
        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)].view(self.window_area, self.window_area, -1)  # Wh*Ww,Wh*Ww,nH
        self.relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        return relative_position_bias.unsqueeze(0)
class SwinAttention(nn.Module):
    """Multi-head self-attention over one window, with relative position bias.

    Args:
        num_heads: Number of attention heads; must divide ``C``.
        C: Channel width of the tokens.
        dropout: Dropout probability for the attention weights and the
            projected output.
        window_size: Side length of the (square) window. Inputs to
            ``forward`` must contain ``window_size ** 2`` tokens.

    Raises:
        ValueError: If ``C`` is not divisible by ``num_heads``.
    """

    def __init__(self, num_heads, C, dropout, window_size=7):
        # nn.Module must be initialised before parameters/buffers are assigned.
        super().__init__()
        if C % num_heads != 0:
            raise ValueError(f"C ({C}) must be divisible by num_heads ({num_heads})")
        self.num_heads = num_heads
        # Scaled dot-product attention scales by the per-head dimension.
        self.scale = (C // num_heads) ** -0.5
        self.qkv = nn.Linear(C, C * 3, bias=True)
        self.softmax = nn.Softmax(dim=-1)
        self.attn_drop = nn.Dropout(dropout)
        self.proj = nn.Linear(C, C)
        self.proj_drop = nn.Dropout(dropout)
        # from timm
        # https://github.com/rwightman/pytorch-image-models/blob/main/timm/models/swin_transformer.py#L134
        self.win_h = window_size
        self.win_w = window_size
        self.window_area = self.win_h * self.win_w
        # define a parameter table of relative position bias, shape: 2*Wh-1 * 2*Ww-1, nH
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * self.win_h - 1) * (2 * self.win_w - 1), num_heads))
        # get pair-wise relative position index for each token inside the window
        self.register_buffer("relative_position_index",
                             self._relative_position_index(self.win_h, self.win_w))

    @staticmethod
    def _relative_position_index(win_h, win_w):
        """Return a (Wh*Ww, Wh*Ww) table of indices into the bias table.

        Tokens are numbered row-major inside the window. Entry (i, j) encodes
        the (row, column) offset from token j to token i as a single integer
        in ``[0, (2*Wh-1) * (2*Ww-1))``.
        """
        rows = torch.arange(win_h).repeat_interleave(win_w)  # row of each token
        cols = torch.arange(win_w).repeat(win_h)  # column of each token
        # Shift offsets to start from 0 before flattening the 2-D offset.
        rel_rows = rows[:, None] - rows[None, :] + (win_h - 1)
        rel_cols = cols[:, None] - cols[None, :] + (win_w - 1)
        return rel_rows * (2 * win_w - 1) + rel_cols

    def _get_rel_pos_bias(self):
        """Gather the learned bias as a (1, nH, Wh*Ww, Wh*Ww) tensor."""
        bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
        bias = bias.view(self.window_area, self.window_area, -1)  # Wh*Ww, Wh*Ww, nH
        return bias.permute(2, 0, 1).contiguous().unsqueeze(0)  # 1, nH, Wh*Ww, Wh*Ww

    def forward(self, x):  # B, L, C
        """Attend within each window.

        Args:
            x: Tensor of shape (B, L, C), one row per window, with
                L == window_size ** 2.

        Returns:
            Tensor of shape (B, L, C).
        """
        B, L, C = x.shape
        qkv = self.qkv(x).reshape(B, L, 3, self.num_heads, C // self.num_heads)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # 3, B, Head, L, C_v
        q, k, v = qkv[0], qkv[1], qkv[2]
        q = q * self.scale
        attn = q @ k.transpose(-1, -2)  # B, Head, L, L
        attn_score = self.softmax(attn + self._get_rel_pos_bias())
        attn_score = self.attn_drop(attn_score)
        # Weight the values with the normalised scores, not the raw logits.
        out = (attn_score @ v).transpose(1, 2).flatten(-2)  # B, L, C
        out = self.proj(out)
        out = self.proj_drop(out)
        return out
class Mlp(nn.Module):
    """Two-layer feed-forward network: Linear -> activation -> dropout -> Linear -> dropout.

    Args:
        in_features: Input and output width.
        hidden_features: Width of the hidden layer; defaults to
            ``in_features`` when omitted.
        act_layer: Activation class, instantiated with no arguments.
        drop: Dropout probability applied after each linear layer.
    """

    def __init__(self, in_features, hidden_features=None, act_layer=nn.GELU, drop=0.1):
        # nn.Module must be initialised before sub-modules can be assigned.
        super().__init__()
        if hidden_features is None:
            # nn.Linear cannot take None as a size, so fall back to the input width.
            hidden_features = in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, in_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Apply the feed-forward network along the last dimension of ``x``."""
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
class PatchMerging(nn.Module):
    """Downsample a square token grid by merging each 2x2 neighbourhood.

    The four tokens of every 2x2 neighbourhood are concatenated along the
    channel axis (4C) and linearly projected to 2C.

    Args:
        C: Channel width of the incoming tokens.
    """

    def __init__(self, C):
        # nn.Module must be initialised before sub-modules can be assigned.
        super().__init__()
        self.proj = nn.Linear(C * 4, C * 2)

    def forward(self, x):  # BS, L, C
        """Merge patches.

        Args:
            x: Tensor of shape (BS, L, C) where L is a perfect square with an
                even side length.

        Returns:
            Tensor of shape (BS, L // 4, 2 * C).

        Raises:
            ValueError: If L is not a perfect square or its side is odd.
        """
        BS, L, C = x.size()
        H = int(math.sqrt(L))
        if H * H != L or H % 2 != 0:
            raise ValueError(f"PatchMerging needs a square grid with even side, got L={L}")
        # Split rows and columns into (index // 2, index % 2) so the four
        # members of each 2x2 neighbourhood can be gathered together; a plain
        # view to 4C would merge four consecutive tokens of one row.
        x = x.view(BS, H // 2, 2, H // 2, 2, C)  # b, i, di, j, dj, c
        # Channel order: (di, dj) = (0, 0), (1, 0), (0, 1), (1, 1).
        x = x.permute(0, 1, 3, 4, 2, 5).reshape(BS, -1, C * 4)
        x = self.proj(x)  # BS, L/4, 2C
        return x
Authored by 임중섭
Edited by 김주영