PyTorch 로 ResNet 구현하기

지금까지 Pytorch 의 기초 문법과 Computer vision 분야의 대표적인 모델 Resnet 에 대해 살펴보았습니다. 이번 페이지에서는 pytorch 로 resnet 모델을 구현하는 방법에 대해 살펴보겠습니다.

  • 담당자: 이유진 님

  • 최종수정일: 21-09-29

  • 본 자료는 가짜연구소 3기 Pytorch guide 크루 활동으로 작성됨

01 주요 개념 및 argument 설명

1.1 주요 개념

다시 한번 주요 개념들을 살펴보겠습니다.

1.1.1 skip connection

ResNet에서는 두가지 형태의 skip connection을 다루게 됩니다. Identity Shortcut(Identity mapping by Shortcut)은 F(x) + x의 구조이며 element-wise addition에 학습대상 parameter가 없습니다. F(x)와 x의 차원이 다를 때에는 x의 차원을 증가시키기 위해 convolution layer를 사용하여 차원을 맞춰주며 이를 Projection Shortcut이라고도 합니다. 이는 F(x) + Wx (W: convolution)으로 구현이 되고 1x1 convolution + batch normalization의 구조를 사용해 차원을 맞춰주며, convolution이 사용되었기 때문에 학습대상 parameter가 존재하게 됩니다.

1.1.2 down sampling

down sampling이란 더 작은 이미지로 크기를 축소시키는 것 입니다. VGG net에서는 출력 feature map 크기를 줄일 때 max pooling을 사용하였지만, ResNet은 복잡도를 줄이기 위해 stride=2으로 대체하였습니다. ResNet에서는 차원이 바뀌는 블록의 첫 번째 convolutional layer에서 stride를 2로 사용하여 feature map 크기를 줄여주었습니다. 따라서 conv3 1, conv4 1, and conv5 1 에서 사용됩니다.

1.1.3 Block

ResNet에서는 34-layer까지는 Basic Block을 사용하였고, 더 깊은 구조에서는 Bottleneck구조를 사용하였습니다. Basic Block 은 3x3 convolution + 3x3 convolution의 구조를 가집니다. ResNet의 깊이가 점점 깊어지면 경우, parameter의 수가 너무 많아지기 때문에 50층 이상인 ResNet에서는 residual block으로 Basic Block대신 Bottleneck Block을 사용하여 층을 쌓게됩니다. 1x1 conv -> 3x3 conv -> 1x1 conv 으로 구성되어있으며 처음 1x1 conv에서 차원을 축소해서 3x3 layer에서는 작은 입출력 값을 갖게되어 연산 부담을 줄이고 마지막 1x1 conv에서 차원을 다시 늘려줍니다.

1.2 arguments

1.2.1 dilation

dilation은 kernel원소 사이 간격을 의미하며(기본값 1) dilation rate에 맞춰 filter 사이에 zero padding을 해서 크기를 늘려줍니다. 필터사이즈가 커지게 되면서 시야가 넓어진다는 장점이 있으며 주로 real-time segmentation에서 효과가 있다고 알려져 있습니다. ResNet 코드부분에서 특이한 점은 padding=dilation로 설정 되어있는데, 이는 input size와 output size가 동일해야할 때 주로 이렇게 설정을 합니다. 3x3 filter에서 dilation rate으로 2를 설정하면 5x5 filter와 동일한 크기가 되게되고, 따라서 그에 맞게 padding도 2로 해줘야 output size가 맞게 됍니다.

1.2.2 width, group

width와 group argument는 WideResNet과 ResNeXt 구현시 사용됩니다.

Wide ResNet의 경우, width_per_group을 사용합니다. 일반적으로 다음과 같이 표기하며 WRN_n_k, n은 total number of layers(깊이), k는 widening factors(폭) 의미합니다. k=1일때가 ResNet과 동일한 width인 경우이며 k가 1보다 큰 만큼, 일반 ResNet보다 넓이가 k배 넓게 구현됩니다. wide_resnet50_2 모델 코드를 예로 들면 wide_resnet50_2 은 총 50층이고 widening 계수(k)가 2인 모델이며 kwargs['width_per_group'] = 64 * 2로 base_width=64에 비해 k=2배로 폭이 증가됨을 볼 수 있습니다.

ResNeXt의 경우, groups와 width_per_group을 사용하여 Cardinality개념을 적용합니다. resnext50_32x4d 모델을 사용하는 코드를 보시면 kwargs['groups'] = 32로 input channel을 32개의 그룹으로 분할(Cardinality)하고kwargs['width_per_group'] = 4로 각 그룹당 4(=128/32)개의 채널으로 구성하게 됩니다.

02 ResNet 공식 코드 설명


PyTorch에 구현된 ResNet관련 공식 코드를 바탕으로 설명하였습니다. (코드에 대한 대부분의 설명은 해당 코드의 주석으로 달아두었습니다.) 해당 함수는 아래와 같이 사용가능하며, 기본 resnet과 더불어 Wide ResNet과 ResNext또한 아래와 같이 사용이 가능합니다.

# 공통 Arguments
## pretrained (bool): If True, returns a model pre-trained on ImageNet
## progress (bool): If True, displays a progress bar of the download to stderr

# 기본 ResNet 34층
def resnet34(pretrained=False, progress=True, **kwargs):
    return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, **kwargs)

# 기본 ResNet 50층
def resnet50(pretrained=False, progress=True, **kwargs):
    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)

# Wide ResNet
def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
    kwargs['width_per_group'] = 64 * 2
    return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)

# ResNext
def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
    kwargs['groups'] = 32 # input channel을 32개의 그룹으로 분할 (cardinality)
    kwargs['width_per_group'] = 4 # 각 그룹당 4(=128/32)개의 채널으로 구성.
    # 각 그룹당 channel 4의 output feautre map 생성, concatenate해서 128개로 다시 생성.
    
    return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    r"""
    - pretrained: pretrained된 모델 가중치를 불러오기 (saved by caffe)
    - arch: ResNet모델 이름
    - block: 어떤 block 형태 사용할지 ("Basic or Bottleneck")
    - layers: 해당 block이 몇번 사용되는지를 list형태로 넘겨주는 부분
    """
    model = ResNet(block, layers, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
        model.load_state_dict(state_dict)
    return model

2.1 ResNet 구조

구현할 ResNet의 각 층과 각 블록은 아래와 같은 형태를 지닙니다.

2.2 Convolution Layer

공식코드에서는 nn.Conv2d을 사용한 convolution layer 를 아래와같이 별도로 정의해두고 사용하였습니다.

def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    r"""
    3x3 convolution with padding
    - in_planes: in_channels
    - out_channels: out_channels
    - bias=False: BatchNorm에 bias가 포함되어 있으므로, conv2d는 bias=False로 설정.
    """
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)

def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)

kernel size가 3x3인 convolution 필터와 1x1인 필터를 따로 정의해두고 사용하였습니다. in_planes는 input data의 채널수이고 out_planes는 output data의 채널 수 입니다.

2.3 Blocks: BasicBlock, Bottleneck

2.3.1 BasicBlock

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        r"""
         - inplanes: input channel size
         - planes: output channel size
         - groups, base_width: ResNext나 Wide ResNet의 경우 사용
        """
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
            
        # Basic Block의 구조
        self.conv1 = conv3x3(inplanes, planes, stride)  # conv1에서 downsample
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        
        # short connection
        if self.downsample is not None:
            identity = self.downsample(x)
            
        # identity mapping시 identity mapping후 ReLU를 적용합니다.
        # 그 이유는, ReLU를 통과하면 양의 값만 남기 때문에 Residual의 의미가 제대로 유지되지 않기 때문입니다.
        out += identity
        out = self.relu(out)

        return out

init 생성자에서 BasicBlock의 convolution, batchnormalization, relu 등 필요한 연산들을 구조에 맞게 정의해두었으며 forward함수를 통해 feedforward 및 identitiy mapping connection을 구현하였습니다.

2.3.2 Bottleneck

class Bottleneck(nn.Module):
    # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
    # while original implementation places the stride at the first 1x1 convolution(self.conv1)
    # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
    # This variant is also known as ResNet V1.5 and improves accuracy according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

    expansion = 4 # 블록 내에서 차원을 증가시키는 3번째 conv layer에서의 확장계수
    
    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        # ResNext나 WideResNet의 경우 사용
        width = int(planes * (base_width / 64.)) * groups
        
        # Bottleneck Block의 구조
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation) # conv2에서 downsample
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x
        # 1x1 convolution layer
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        # 3x3 convolution layer
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        # 1x1 convolution layer
        out = self.conv3(out)
        out = self.bn3(out)
        # skip connection
        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out

클래스의 구조는 위의 Basic Block과 동일하며 앞의 설명대로 conv1x1이 추가로 사용되었습니다. 또한 conv1이 아니라 conv2에서 downsampling이 이루어진게 의아할 수 있지만, 관련해서는 한 유저가 Kaming He의 답변을 공유함으로써 알 수 있습니다. 결론적으로는 큰 의미가 있는 것은 아니며 어디에 downsampling을 적용해도 무방할 것으로 보입니다.

In all experiments in the paper, the stride=2 operation is in the first 1x1 conv layer when downsampling. This might not be the best choice, as it wastes some computations of the preceding block. For example, using stride=2 in the first 1x1 conv in the first block of conv3 is equivalent to using stride=2 in the 3x3 conv in the last block of conv2. So I feel applying stride=2 to either the first 1x1 or the 3x3 conv should work. I just kept it “as is”, because we do not have enough resources to investigate every choice.

2.3.3 ResNet class

class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer
        # default values
        self.inplanes = 64 # input feature map
        self.dilation = 1
        # stride를 dilation으로 대체할지 선택
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        
        r"""
        - 처음 입력에 적용되는 self.conv1과 self.bn1, self.relu는 모든 ResNet에서 동일 
        - 3: 입력으로 RGB 이미지를 사용하기 때문에 convolution layer에 들어오는 input의 channel 수는 3
        """
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        
        r"""
        - 아래부터 block 형태와 갯수가 ResNet층마다 변화
        - self.layer1 ~ 4: 필터의 개수는 각 block들을 거치면서 증가(64->128->256->512)
        - self.avgpool: 모든 block을 거친 후에는 Adaptive AvgPool2d를 적용하여 (n, 512, 1, 1)의 텐서로
        - self.fc: 이후 fc layer를 연결
        """
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, # 여기서부터 downsampling적용
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        r"""
        convolution layer 생성 함수
        - block: block종류 지정
        - planes: feature map size (input shape)
        - blocks: layers[0]와 같이, 해당 블록이 몇개 생성돼야하는지, 블록의 갯수 (layer 반복해서 쌓는 개수)
        - stride와 dilate은 고정
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        
        # the number of filters is doubled: self.inplanes와 planes 사이즈를 맞춰주기 위한 projection shortcut
        # the feature map size is halved: stride=2로 downsampling
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        # 블록 내 시작 layer, downsampling 필요
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion # inplanes 업데이트
        # 동일 블록 반복
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x):
        return self._forward_impl(x)

결국 최종적으로 ResNet 모델은 ResNet class를 통해 사용이 됩니다. __init__에서 처음 conv1층과 마지막층(pooing과 fully connected) 이외에는 _make_layer함수로 모델의 제일 큰 단위의 층을 생성 및 정의합니다. _make_layer함수는 논문의 conv2_X, conv3_x, conv4_X, conv5_x을 구현하며 각 층에 해당하는 block을 갯수에 맞게 생성 및 연결시켜주는 역할을 합니다. 이렇게 생성자에서 정의된 층들은 forward함수로 모델에 대한 feedforward를 진행합니다. 이후 학습에서는 loss를 지정한 후 해당 모델에 대한 backpropagation을 진행하게 됩니다.

Reference