SepViT: Separable Vision Transformer Code#

Helpers#

from functools import partial

import torch
from torch import nn, einsum

from einops import rearrange, repeat
from einops.layers.torch import Rearrange, Reduce

# helpers

def cast_tuple(val, length = 1):
    """Return ``val`` as-is when it is a tuple, else a tuple repeating it ``length`` times."""
    if isinstance(val, tuple):
        # tuples are passed through untouched, whatever their length
        return val
    return (val,) * length

# helper classes

class ChanLayerNorm(nn.Module):
    """Layer normalisation computed over the channel axis (dim 1).

    Mean and variance are taken along dim 1 for every remaining position,
    then a learned gain ``g`` and bias ``b`` are applied. Both parameters have
    shape ``(1, dim, 1, 1)``, so they broadcast against 4-d inputs whose second
    axis has size ``dim``.
    """

    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        param_shape = (1, dim, 1, 1)
        self.eps = eps
        self.g = nn.Parameter(torch.ones(param_shape))
        self.b = nn.Parameter(torch.zeros(param_shape))

    def forward(self, x):
        # biased (population) variance; eps keeps the division stable
        variance = x.var(dim = 1, unbiased = False, keepdim = True)
        centered = x - x.mean(dim = 1, keepdim = True)
        normalized = centered / (variance + self.eps).sqrt()
        return normalized * self.g + self.b

class PreNorm(nn.Module):
    """Wrap ``fn`` so that its input is channel-normalised first."""

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = ChanLayerNorm(dim)
        self.fn = fn

    def forward(self, x):
        normed = self.norm(x)
        return self.fn(normed)

class OverlappingPatchEmbed(nn.Module):
    """Strided convolution that downsamples by ``stride`` with overlapping patches.

    The kernel is ``2 * stride - 1`` wide, so neighbouring patches overlap, and
    the padding is half the kernel size (rounded down).
    """

    def __init__(self, dim_in, dim_out, stride = 2):
        super().__init__()
        kernel_size = 2 * stride - 1
        self.conv = nn.Conv2d(
            dim_in,
            dim_out,
            kernel_size,
            stride = stride,
            padding = kernel_size // 2,
        )

    def forward(self, x):
        embedded = self.conv(x)
        return embedded

class PEG(nn.Module):
    """Depthwise convolution added back onto its input as a residual.

    The convolution uses ``groups = dim`` (one filter per channel), stride 1 and
    padding of half the kernel size (rounded down).
    """

    def __init__(self, dim, kernel_size = 3):
        super().__init__()
        self.proj = nn.Conv2d(
            dim,
            dim,
            kernel_size = kernel_size,
            padding = kernel_size // 2,
            groups = dim,
            stride = 1,
        )

    def forward(self, x):
        projected = self.proj(x)
        return projected + x

Transformer Module#

# feedforward

class FeedForward(nn.Module):
    """Position-wise MLP built from 1x1 convolutions.

    Expands ``dim`` channels to ``int(dim * mult)``, applies GELU, and projects
    back to ``dim``; dropout with probability ``dropout`` follows the
    activation and the output projection.
    """

    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        hidden_dim = int(dim * mult)
        stages = [
            nn.Conv2d(dim, hidden_dim, 1),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Conv2d(hidden_dim, dim, 1),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)

# attention

class DSSA(nn.Module):
    """Windowed self-attention followed by attention across windows.

    The feature map is cut into non-overlapping ``window_size`` x
    ``window_size`` windows and a learned window token is prepended to each.
    Attention is first computed inside every window (the "depthwise" step);
    the attended window tokens are then used to compute attention between
    windows, which mixes the per-window feature maps (the "pointwise" step).
    When the feature map consists of a single window the second step is
    skipped.
    """

    def __init__(
        self,
        dim,
        heads = 8,
        dim_head = 32,
        dropout = 0.,
        window_size = 7
    ):
        super().__init__()
        inner_dim = heads * dim_head

        self.heads = heads
        self.scale = dim_head ** -0.5
        self.window_size = window_size

        # softmax + dropout over the within-window similarities
        self.attend = nn.Sequential(
            nn.Softmax(dim = -1),
            nn.Dropout(dropout)
        )

        # a single pointwise conv yields queries, keys and values together
        self.to_qkv = nn.Conv1d(dim, inner_dim * 3, 1, bias = False)

        # learned token that is prepended to the tokens of every window
        self.window_tokens = nn.Parameter(torch.randn(dim))

        # normalise and activate the attended window tokens, then project them
        # (heads folded into channels for the conv) to window queries and keys
        self.window_tokens_to_qk = nn.Sequential(
            nn.LayerNorm(dim_head),
            nn.GELU(),
            Rearrange('b h n c -> b (h c) n'),
            nn.Conv1d(inner_dim, inner_dim * 2, 1),
            Rearrange('b (h c) n -> b h n c', h = heads),
        )

        # softmax + dropout over the between-window similarities
        self.window_attend = nn.Sequential(
            nn.Softmax(dim = -1),
            nn.Dropout(dropout)
        )

        # merge the heads back into ``dim`` channels
        self.to_out = nn.Sequential(
            nn.Conv2d(inner_dim, dim, 1),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        """Apply within-window and then cross-window attention to ``x``.

        ``x`` is a feature map whose last two axes (height, width) must both be
        divisible by the window size; the result has the same layout with
        ``dim`` channels.

        Axis names used in the patterns below:

        b      - batch
        c      - channels
        h      - heads (in the first rearrange ``h`` / ``w`` are window-grid rows / columns)
        d      - per-head feature dimension
        w1, w2 - height and width of one window
        x, y   - number of windows along height and width
        n      - tokens per window
        i, j   - query position and key position (``j`` is summed over)
        """

        heads, wsz = self.heads, self.window_size
        height, width = x.shape[-2:]
        assert (height % wsz) == 0 and (width % wsz) == 0, f'height {height} and width {width} must be divisible by window size {wsz}'

        grid = dict(x = height // wsz, y = width // wsz)
        window = dict(w1 = wsz, w2 = wsz)

        # every window becomes its own sequence: (b * x * y, c, w1 * w2)
        tokens = rearrange(x, 'b c (h w1) (w w2) -> (b h w) c (w1 w2)', **window)

        # prepend the window token: (b * x * y, c, 1 + w1 * w2)
        prefix = repeat(self.window_tokens, 'c -> b c 1', b = tokens.shape[0])
        tokens = torch.cat((prefix, tokens), dim = -1)

        # queries, keys, values with heads split out: (b * x * y, h, 1 + w1 * w2, d)
        q, k, v = (
            rearrange(t, 'b (h d) ... -> b h (...) d', h = heads)
            for t in self.to_qkv(tokens).chunk(3, dim = 1)
        )

        # scaled dot-product attention inside each window
        sim = einsum('b h i d, b h j d -> b h i j', q * self.scale, k)
        attended = torch.matmul(self.attend(sim), v)

        # position 0 holds the window token, the rest are the window's pixels
        window_tokens, windowed_fmaps = attended[:, :, 0], attended[:, :, 1:]

        if grid['x'] * grid['y'] == 1:
            # a single window has nothing to exchange with; just fold it back
            fmap = rearrange(windowed_fmaps, '(b x y) h (w1 w2) d -> b (h d) (x w1) (y w2)', **grid, **window)
        else:
            # gather the windows of each image: tokens (b, h, x * y, d), maps (b, h, x * y, n, d)
            window_tokens = rearrange(window_tokens, '(b x y) h d -> b h (x y) d', **grid)
            windowed_fmaps = rearrange(windowed_fmaps, '(b x y) h n d -> b h (x y) n d', **grid)

            # window-level queries and keys derived from the window tokens
            w_q, w_k = self.window_tokens_to_qk(window_tokens).chunk(2, dim = -1)

            # attention between windows: (b, h, x * y, x * y)
            w_sim = einsum('b h i d, b h j d -> b h i j', w_q * self.scale, w_k)
            w_attn = self.window_attend(w_sim)

            # each output window is an attention-weighted sum of all window feature maps
            mixed = einsum('b h i j, b h j w d -> b h i w d', w_attn, windowed_fmaps)

            # fold windows back into the image grid and merge the heads
            fmap = rearrange(mixed, 'b h (x y) (w1 w2) d -> b (h d) (x w1) (y w2)', **grid, **window)

        return self.to_out(fmap)

class Transformer(nn.Module):
    """Stack of ``depth`` pre-normalised DSSA + feed-forward blocks.

    Every block applies attention and then the feed-forward network, each
    wrapped in ``PreNorm`` and added back to its input as a residual.

    Args:
        dim: number of channels of the feature map.
        depth: number of (attention, feed-forward) blocks.
        dim_head: per-head feature dimension of the attention.
        heads: number of attention heads.
        ff_mult: expansion factor of the feed-forward hidden width.
        dropout: dropout probability passed to attention and feed-forward.
        norm_output: apply a final ``ChanLayerNorm`` when true, identity otherwise.
        window_size: side length of the attention windows passed to ``DSSA``.
            It was previously impossible to set, so every block silently used
            the ``DSSA`` default; the default here keeps that behaviour.
    """

    def __init__(
        self,
        dim,
        depth,
        dim_head = 32,
        heads = 8,
        ff_mult = 4,
        dropout = 0.,
        norm_output = True,
        window_size = 7
    ):
        super().__init__()
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, DSSA(dim, heads = heads, dim_head = dim_head, dropout = dropout, window_size = window_size)),
                PreNorm(dim, FeedForward(dim, mult = ff_mult, dropout = dropout)),
            ]))

        self.norm = ChanLayerNorm(dim) if norm_output else nn.Identity()

    def forward(self, x):
        for attn, ff in self.layers:
            # residual connections around both sub-layers
            x = attn(x) + x
            x = ff(x) + x

        return self.norm(x)

SepViT#

class SepViT(nn.Module):
    """Multi-stage image classifier built from DSSA transformer stages.

    Each stage downsamples with an overlapping patch embedding (stride 4 for
    the first stage, 2 afterwards), adds a ``PEG`` positional encoding and runs
    a ``Transformer``. The final feature map is mean-pooled over space,
    layer-normalised and projected to ``num_classes`` logits.

    Args (all keyword-only):
        num_classes: size of the classification output.
        dim: channel width of the first stage; it doubles at every later stage.
        depth: tuple with the number of transformer blocks per stage; its
            length defines the number of stages.
        heads: attention heads, either one value for all stages or a tuple
            with one entry per stage.
        window_size: attention window size, either one value for all stages or
            a tuple with one entry per stage.
        dim_head: per-head feature dimension of the attention.
        ff_mult: expansion factor of the feed-forward hidden width.
        channels: number of channels of the input image.
        dropout: dropout probability used inside the transformer blocks.

    Raises:
        AssertionError: if ``depth`` is not a tuple, or if ``heads`` /
            ``window_size`` is a tuple whose length differs from ``depth``.
    """

    def __init__(
        self,
        *,
        num_classes,
        dim,
        depth,
        heads,
        window_size = 7,
        dim_head = 32,
        ff_mult = 4,
        channels = 3,
        dropout = 0.
    ):
        super().__init__()
        assert isinstance(depth, tuple), 'depth needs to be tuple if integers indicating number of transformer blocks at that stage'

        num_stages = len(depth)

        # channel width doubles at every stage; the input channels are prepended
        # so consecutive entries form the (in, out) pair of each patch embedding
        dims = (channels, *((2 ** i) * dim for i in range(num_stages)))
        dim_pairs = tuple(zip(dims[:-1], dims[1:]))

        strides = (4, *((2,) * (num_stages - 1)))

        heads_per_stage = cast_tuple(heads, length = num_stages)
        window_size_per_stage = cast_tuple(window_size, length = num_stages)
        assert len(heads_per_stage) == num_stages and len(window_size_per_stage) == num_stages, \
            'heads and window_size must each be a single value or a tuple with one entry per stage'

        self.layers = nn.ModuleList([])

        stage_configs = zip(dim_pairs, depth, strides, heads_per_stage, window_size_per_stage)

        for ind, ((layer_dim_in, layer_dim), layer_depth, layer_stride, layer_heads, layer_window_size) in enumerate(stage_configs):
            is_last = ind == (num_stages - 1)

            # sub-modules are created in the order patch embed -> PEG -> transformer
            # so that seeded weight initialisation stays unchanged
            patch_embed = OverlappingPatchEmbed(layer_dim_in, layer_dim, stride = layer_stride)
            peg = PEG(layer_dim)

            # dim_head used to be dropped here, leaving the attention at its default head size
            transformer = Transformer(
                dim = layer_dim,
                depth = layer_depth,
                dim_head = dim_head,
                heads = layer_heads,
                ff_mult = ff_mult,
                dropout = dropout,
                norm_output = not is_last,
            )

            # the per-stage window size used to be computed but never applied;
            # DSSA only reads ``window_size`` in its forward pass, so setting it on
            # the attention modules directly takes effect however they were built
            for module in transformer.modules():
                if isinstance(module, DSSA):
                    module.window_size = layer_window_size

            self.layers.append(nn.ModuleList([patch_embed, peg, transformer]))

        self.mlp_head = nn.Sequential(
            Reduce('b d h w -> b d', 'mean'),
            nn.LayerNorm(dims[-1]),
            nn.Linear(dims[-1], num_classes)
        )

    def forward(self, x):
        # each stage: downsample, add positional encoding, then attend
        for ope, peg, transformer in self.layers:
            x = ope(x)
            x = peg(x)
            x = transformer(x)

        # pool over space and classify
        return self.mlp_head(x)

Main#

# Usage example: build the SepViT-Lite configuration and run one random
# 224 x 224 RGB image through it. With stage strides 4, 2, 2, 2 the feature
# maps are 56, 28, 14 and 7 pixels wide, all divisible by the window size 7.
v = SepViT(
    num_classes = 1000,
    dim = 32,               # channel width of the first stage; doubles every stage (32, 64, 128, 256) for SepViT-Lite
    dim_head = 32,          # attention head dimension
    heads = (1, 2, 4, 8),   # number of attention heads per stage
    depth = (1, 2, 6, 2),   # number of transformer blocks per stage (its length sets the number of stages)
    window_size = 7,        # window size of the DSS attention block
    dropout = 0.1           # dropout probability
)

img = torch.randn(1, 3, 224, 224)

# NOTE: the model has not been switched to eval mode, so dropout is active here
preds = v(img) # logits of shape (1, 1000)

Authored by 정영상
Edited by 김주영