MobileViT V2 Code
Contents
MobileViT V2 Code
https://github.com/rwightman/pytorch-image-models/blob/main/timm/models/mobilevit.py
MobileVitV2Block
class MobileVitV2Block(nn.Module):
    """MobileViT-v2 style block: local convolution, patch-wise linear attention, 1x1 projection.

    The input feature map is refined by a k x k convolution, projected to
    ``transformer_dim`` channels, rearranged into non-overlapping patches,
    passed through a stack of ``LinearTransformerBlock`` layers, rearranged
    back into a feature map and projected to ``out_chs`` channels.

    Args:
        in_chs: Number of input channels.
        out_chs: Number of output channels; defaults to ``in_chs``.
        kernel_size: Kernel size of the local convolution.
        bottle_ratio: Used only when ``transformer_dim`` is not given; the
            transformer width is then ``bottle_ratio * in_chs`` rounded to a
            multiple of 8.
        group_size: Forwarded to ``num_groups`` together with ``in_chs`` to
            obtain the group count of the local convolution.
        dilation: Only ``dilation[0]`` is used, for the local convolution.
        mlp_ratio: Hidden-width multiplier of each transformer block's MLP.
        transformer_dim: Channel width inside the transformer stack.
        transformer_depth: Number of ``LinearTransformerBlock`` layers.
        patch_size: Patch edge length (int or pair), normalised with ``to_2tuple``.
        attn_drop: Dropout rate forwarded to the attention layers.
        drop: Dropout rate forwarded to the transformer blocks.
        drop_path_rate: Drop-path rate forwarded to the transformer blocks.
        layers: Layer factory providing ``conv_norm_act`` and ``act``;
            a default ``LayerFn()`` is created when omitted.
        transformer_norm_layer: Norm layer constructor used inside and after
            the transformer stack.
        **kwargs: Accepted and ignored.
    """

    def __init__(
            self,
            in_chs: int,
            out_chs: Optional[int] = None,
            kernel_size: int = 3,
            bottle_ratio: float = 1.0,
            group_size: Optional[int] = 1,
            dilation: Tuple[int, int] = (1, 1),
            mlp_ratio: float = 2.0,
            transformer_dim: Optional[int] = None,
            transformer_depth: int = 2,
            patch_size: int = 8,
            attn_drop: float = 0.,
            drop: float = 0.,
            drop_path_rate: float = 0.,
            layers: Optional[LayerFn] = None,
            transformer_norm_layer: Callable = GroupNorm1,
            **kwargs,  # eat unused args
    ):
        super().__init__()
        # The declared defaults are None, so both values need a real fallback;
        # otherwise the constructor fails on the very next lines.
        # NOTE(review): assumes LayerFn() can be constructed without arguments - confirm.
        layers = layers or LayerFn()
        groups = num_groups(group_size, in_chs)
        out_chs = out_chs or in_chs
        transformer_dim = transformer_dim or self._round_channels(bottle_ratio * in_chs)

        # Local representation: k x k conv (+ norm + act) followed by a 1x1 channel projection.
        self.conv_kxk = layers.conv_norm_act(
            in_chs, in_chs, kernel_size=kernel_size,
            stride=1, groups=groups, dilation=dilation[0])
        self.conv_1x1 = nn.Conv2d(in_chs, transformer_dim, kernel_size=1, bias=False)

        # Global representation: stack of linear-attention transformer blocks plus a final norm.
        self.transformer = nn.Sequential(*[
            LinearTransformerBlock(
                transformer_dim,
                mlp_ratio=mlp_ratio,
                attn_drop=attn_drop,
                drop=drop,
                drop_path=drop_path_rate,
                act_layer=layers.act,
                norm_layer=transformer_norm_layer,
            )
            for _ in range(transformer_depth)
        ])
        self.norm = transformer_norm_layer(transformer_dim)

        # Output projection back to out_chs, without an activation.
        self.conv_proj = layers.conv_norm_act(transformer_dim, out_chs, kernel_size=1, stride=1, apply_act=False)

        self.patch_size = to_2tuple(patch_size)
        self.patch_area = self.patch_size[0] * self.patch_size[1]

    @staticmethod
    def _round_channels(value: float, divisor: int = 8, round_limit: float = 0.9) -> int:
        """Round ``value`` to a multiple of ``divisor``, never below ``divisor``.

        If rounding down would lose more than ``1 - round_limit`` of ``value``,
        the result is bumped up by one ``divisor``.
        """
        rounded = max(divisor, int(value + divisor / 2) // divisor * divisor)
        if rounded < round_limit * value:
            rounded += divisor
        return rounded

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block to a ``[B, C, H, W]`` feature map.

        When ``H`` or ``W`` is not a multiple of the patch size, the input is
        bilinearly resized up to the next multiple first; the output keeps
        that resized spatial size.

        Returns:
            Tensor of shape ``[B, out_chs, new_h, new_w]``.
        """
        B, C, H, W = x.shape
        patch_h, patch_w = self.patch_size
        # Round the spatial size up to the next multiple of the patch size.
        new_h, new_w = math.ceil(H / patch_h) * patch_h, math.ceil(W / patch_w) * patch_w
        num_patch_h, num_patch_w = new_h // patch_h, new_w // patch_w  # n_h, n_w
        num_patches = num_patch_h * num_patch_w  # N
        if new_h != H or new_w != W:
            x = F.interpolate(x, size=(new_h, new_w), mode="bilinear", align_corners=True)

        # Local representation
        x = self.conv_kxk(x)
        x = self.conv_1x1(x)

        # Unfold (feature map -> patches), [B, C, H, W] -> [B, C, P, N]
        C = x.shape[1]  # channel count is now transformer_dim
        x = x.reshape(B, C, num_patch_h, patch_h, num_patch_w, patch_w).permute(0, 1, 3, 5, 2, 4)
        x = x.reshape(B, C, -1, num_patches)

        # Global representations
        x = self.transformer(x)
        x = self.norm(x)

        # Fold (patches -> feature map), [B, C, P, N] --> [B, C, H, W]
        x = x.reshape(B, C, patch_h, patch_w, num_patch_h, num_patch_w).permute(0, 1, 4, 2, 5, 3)
        x = x.reshape(B, C, num_patch_h * patch_h, num_patch_w * patch_w)

        x = self.conv_proj(x)
        return x
LinearTransformerBlock
class LinearTransformerBlock(nn.Module):
    """Pre-norm transformer block: linear attention followed by a conv MLP.

    Both sub-layers are wrapped in a residual connection with drop-path.

    Args:
        embed_dim: Channel width of the block's input and output.
        mlp_ratio: The MLP hidden width is ``int(embed_dim * mlp_ratio)``.
        drop: Dropout rate passed to the attention output and to the MLP.
        attn_drop: Dropout rate passed to the attention scores.
        drop_path: Rate passed to both ``DropPath`` layers.
        act_layer: Activation constructor for the MLP; ``nn.SiLU`` when falsy.
        norm_layer: Norm constructor; ``GroupNorm1`` when falsy.
    """

    def __init__(
            self,
            embed_dim: int,
            mlp_ratio: float = 2.0,
            drop: float = 0.0,
            attn_drop: float = 0.0,
            drop_path: float = 0.0,
            act_layer=None,
            norm_layer=None,
    ) -> None:
        super().__init__()
        activation = act_layer if act_layer else nn.SiLU
        make_norm = norm_layer if norm_layer else GroupNorm1
        hidden_dim = int(embed_dim * mlp_ratio)

        # Attention sub-layer.
        self.norm1 = make_norm(embed_dim)
        self.attn = LinearSelfAttention(embed_dim=embed_dim, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path1 = DropPath(drop_path)

        # Feed-forward sub-layer.
        self.norm2 = make_norm(embed_dim)
        self.mlp = ConvMlp(
            in_features=embed_dim,
            hidden_features=hidden_dim,
            act_layer=activation,
            drop=drop,
        )
        self.drop_path2 = DropPath(drop_path)

    def forward(self, x: torch.Tensor, x_prev: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run attention then the MLP, each with a residual connection.

        When ``x_prev`` is given it is passed to the attention layer as a
        second argument (cross-attention); otherwise self-attention is used.
        """
        if x_prev is not None:
            # Cross-attention: normalise, attend against x_prev, add the residual.
            shortcut = x
            attended = self.attn(self.norm1(x), x_prev)
            x = self.drop_path1(attended) + shortcut
        else:
            # Self-attention.
            attended = self.attn(self.norm1(x))
            x = x + self.drop_path1(attended)

        # Feed-forward network.
        ffn_out = self.mlp(self.norm2(x))
        return x + self.drop_path2(ffn_out)
LinearSelfAttention
class LinearSelfAttention(nn.Module):
    """Linear-complexity attention using a single-channel query.

    One 1x1 convolution produces a 1-channel query, an ``embed_dim``-channel
    key and an ``embed_dim``-channel value. The query is softmax-normalised
    along the last axis and used to pool the key into a context vector, which
    then scales the ReLU-activated value before the output projection.

    Inputs are 4-D tensors laid out as ``[B, C, P, N]`` with ``C == embed_dim``.

    Args:
        embed_dim: Number of input and output channels.
        attn_drop: Dropout rate applied to the softmax scores.
        proj_drop: Dropout rate applied after the output projection.
        bias: Whether the two 1x1 convolutions carry a bias term.
    """

    def __init__(
            self,
            embed_dim: int,
            attn_drop: float = 0.0,
            proj_drop: float = 0.0,
            bias: bool = True,
    ) -> None:
        super().__init__()
        self.embed_dim = embed_dim

        # Output channels are laid out as [query (1) | key (d) | value (d)].
        self.qkv_proj = nn.Conv2d(
            in_channels=embed_dim,
            out_channels=1 + (2 * embed_dim),
            bias=bias,
            kernel_size=1,
        )
        self.attn_drop = nn.Dropout(attn_drop)
        self.out_proj = nn.Conv2d(
            in_channels=embed_dim,
            out_channels=embed_dim,
            bias=bias,
            kernel_size=1,
        )
        self.out_drop = nn.Dropout(proj_drop)

    def _forward_self_attn(self, x: torch.Tensor) -> torch.Tensor:
        """Attend ``x`` to itself; returns a tensor with the same shape as ``x``."""
        # [B, C, P, N] --> [B, 1 + 2d, P, N]
        qkv = self.qkv_proj(x)

        # Project x into query, key and value
        # Query --> [B, 1, P, N]
        # value, key --> [B, d, P, N]
        query, key, value = qkv.split([1, self.embed_dim, self.embed_dim], dim=1)

        # apply softmax along N dimension
        context_scores = F.softmax(query, dim=-1)
        context_scores = self.attn_drop(context_scores)

        # Compute context vector
        # [B, d, P, N] x [B, 1, P, N] -> [B, d, P, N] --> [B, d, P, 1]
        context_vector = (key * context_scores).sum(dim=-1, keepdim=True)

        # combine context vector with values
        # [B, d, P, N] * [B, d, P, 1] --> [B, d, P, N]
        out = F.relu(value) * context_vector.expand_as(value)
        out = self.out_proj(out)
        out = self.out_drop(out)
        return out

    @torch.jit.ignore()
    def _forward_cross_attn(self, x: torch.Tensor, x_prev: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Attend ``x`` against ``x_prev``.

        Query and key are computed from ``x_prev``; the value is computed from
        ``x``. Returns a tensor with the same shape as ``x``.

        Raises:
            AssertionError: If ``x`` and ``x_prev`` differ in size along the
                second-to-last axis.
        """
        # x --> [B, C, P, N]
        # x_prev = [B, C, P, M]
        patch_area = x.shape[-2]
        prev_patch_area = x_prev.shape[-2]
        assert (
            patch_area == prev_patch_area
        ), "The number of pixels in a patch for query and key_value should be the same"

        weight = self.qkv_proj.weight
        bias = self.qkv_proj.bias  # None when the layer was built with bias=False
        num_qk = self.embed_dim + 1  # leading output channels: 1 query + d key

        # compute query, key, and value
        # [B, C, P, M] --> [B, 1 + d, P, M]
        qk = F.conv2d(
            x_prev,
            weight=weight[:num_qk],
            bias=None if bias is None else bias[:num_qk],
        )
        # [B, 1 + d, P, M] --> [B, 1, P, M], [B, d, P, M]
        query, key = qk.split([1, self.embed_dim], dim=1)

        # [B, C, P, N] --> [B, d, P, N]
        # Slice (not index) so the weight stays 4-D and covers all d value channels.
        value = F.conv2d(
            x,
            weight=weight[num_qk:],
            bias=None if bias is None else bias[num_qk:],
        )

        # apply softmax along M dimension
        context_scores = F.softmax(query, dim=-1)
        context_scores = self.attn_drop(context_scores)

        # compute context vector
        # [B, d, P, M] * [B, 1, P, M] -> [B, d, P, M] --> [B, d, P, 1]
        context_vector = (key * context_scores).sum(dim=-1, keepdim=True)

        # combine context vector with values
        # [B, d, P, N] * [B, d, P, 1] --> [B, d, P, N]
        out = F.relu(value) * context_vector.expand_as(value)
        out = self.out_proj(out)
        out = self.out_drop(out)
        return out

    def forward(self, x: torch.Tensor, x_prev: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Dispatch to self-attention, or to cross-attention when ``x_prev`` is given."""
        if x_prev is None:
            return self._forward_self_attn(x)
        else:
            return self._forward_cross_attn(x, x_prev=x_prev)
Authored by 김주영
Edited by 김주영