PyTorch 로 Transfer-Learning 하기

이전 챕터에서 pytorch 로 resnet 구현과 관련한 내용을 다루었습니다. 이번 노트북에서는 pytorch 로 resnet 모델을 학습하는 방법에 대해 살펴보겠습니다.

  • 담당자: 권지현 님

  • 최종수정일: 21-09-29

  • 본 자료는 가짜연구소 3기 Pytorch guide 크루 활동으로 작성됨

01 data load

본 노트북에서는 torchvision 에서 제공하는 데이터 셋을 활용합니다. torchvision 에 대한 설명은 링크 를 참조바랍니다.

데이터셋을 활용하기 위한 라이브러리를 import 하겠습니다.

# torchvision 관련 라이브러리 import

from torchvision import utils
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

사용할 데이터 셋은 STL10 입니다. STL10은 Image Classification 의 벤치마크로 10개의 라벨을 가진 데이터 셋 입니다. torchvisvion 에서는 5000개의 train 데이터와 8000개의 test 로 구성되어 있으며, datasets.STL10 매소드로 다운받을 수 있습니다.

경로를 설정한 후 train, test 데이터를 다운받습니다. 경로는 단순히 root 에 폴더를 생성하여 지정하였습니다.

transforms 은 ToTensor()로 설정합니다. transforms 에 대한 설명은 링크 를 참조 바랍니다.

import os
os.mkdir('./train')
os.mkdir('./test')

train_dataset = datasets.STL10('/train', split='train', download=True, transform=transforms.ToTensor())
test_dataset = datasets.STL10('/test', split='test', download=True, transform=transforms.ToTensor())
Files already downloaded and verified
Downloading http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz to /test/stl10_binary.tar.gz
Extracting /test/stl10_binary.tar.gz to /test

다운받은 이미지에 대해 스케일링 과정이 필요합니다. transform 을 활용하여 이미지 크기를 고정하고, normalization을 진행합니다.

주어진 데이터셋의 이미지는 RGB 3개의 채널로 구성되어 있으므로, 우선 채널 별 mean 값과 std 값을 계산한 후 transform 을 정의합니다.

import numpy as np

# 채널 별 mean 계산
def get_mean(dataset):
  meanRGB = [np.mean(image.numpy(), axis=(1,2)) for image,_ in dataset]
  meanR = np.mean([m[0] for m in meanRGB])
  meanG = np.mean([m[1] for m in meanRGB])
  meanB = np.mean([m[2] for m in meanRGB])
  return [meanR, meanG, meanB]

# 채널 별 str 계산
def get_std(dataset):
  stdRGB = [np.std(image.numpy(), axis=(1,2)) for image,_ in dataset]
  stdR = np.mean([s[0] for s in stdRGB])
  stdG = np.mean([s[1] for s in stdRGB])
  stdB = np.mean([s[2] for s in stdRGB])
  return [stdR, stdG, stdB]

transforms.Compose 매소드로 trainsform 단계를 묶어서 진행할 수 있습니다. 본 노트북에서는 이미지의 크기를 임의로 128로 고정한 후, 정규화하는 과정만 진행하겠습니다. 보다 다양한 augmentation 방법은 링크 를 참조바랍니다.

train_transforms = transforms.Compose([transforms.Resize((128, 128)),
                                       transforms.ToTensor(),
                                       transforms.Normalize(get_mean(train_dataset), get_std(train_dataset))])
test_transforms = transforms.Compose([transforms.Resize((128, 128)),
                                      transforms.ToTensor(),
                                      transforms.Normalize(get_mean(test_dataset), get_std(test_dataset))])

# trainsform 정의
train_dataset.transform = train_transforms
test_dataset.transform = test_transforms

이후 dataloader 를 정의합니다. batch size 는 임의로 128, 64로 설정해두었습니다.

# dataloader 정의
train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False)

데이터 사용을 위한 준비가 완료되었습니다.

02 training

모델 학습을 위해 pretrained 된 resnet 50 모델을 사용하겠습니다. 해당 resnet 모델은 사전 학습된 모델로, 이미지 분류 문제를 해결할 수 있도록 규모가 큰 데이터(ImageNet)로 미리 학습된 모델을 의미합니다.

torchvision model 에서 구현된 resnet의 구조는 이전 챕터에서 다루었습니다. 관련 내용은 링크 를 참조 바랍니다.

from torchvision import models

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 학습 환경 설정
model = models.resnet50(pretrained=True).to(device) # true 옵션으로 사전 학습된 모델을 로드
Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth

torchsummary의 summary 매소드로 모델을 요약하여 확인할 수 있습니다. 사용할 데이터 셋은 128*128 크기의 RGB 3개의 chennel로 구성되어 있습니다. 모델의 layer 별 파라미터 개수는 다음과 같습니다.

from torchsummary import summary
summary(model, (3, 128, 128))
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 64, 64, 64]           9,408
       BatchNorm2d-2           [-1, 64, 64, 64]             128
              ReLU-3           [-1, 64, 64, 64]               0
         MaxPool2d-4           [-1, 64, 32, 32]               0
            Conv2d-5           [-1, 64, 32, 32]           4,096
       BatchNorm2d-6           [-1, 64, 32, 32]             128
              ReLU-7           [-1, 64, 32, 32]               0
            Conv2d-8           [-1, 64, 32, 32]          36,864
       BatchNorm2d-9           [-1, 64, 32, 32]             128
             ReLU-10           [-1, 64, 32, 32]               0
           Conv2d-11          [-1, 256, 32, 32]          16,384
      BatchNorm2d-12          [-1, 256, 32, 32]             512
           Conv2d-13          [-1, 256, 32, 32]          16,384
      BatchNorm2d-14          [-1, 256, 32, 32]             512
             ReLU-15          [-1, 256, 32, 32]               0
       Bottleneck-16          [-1, 256, 32, 32]               0
           Conv2d-17           [-1, 64, 32, 32]          16,384
      BatchNorm2d-18           [-1, 64, 32, 32]             128
             ReLU-19           [-1, 64, 32, 32]               0
           Conv2d-20           [-1, 64, 32, 32]          36,864
      BatchNorm2d-21           [-1, 64, 32, 32]             128
             ReLU-22           [-1, 64, 32, 32]               0
           Conv2d-23          [-1, 256, 32, 32]          16,384
      BatchNorm2d-24          [-1, 256, 32, 32]             512
             ReLU-25          [-1, 256, 32, 32]               0
       Bottleneck-26          [-1, 256, 32, 32]               0
           Conv2d-27           [-1, 64, 32, 32]          16,384
      BatchNorm2d-28           [-1, 64, 32, 32]             128
             ReLU-29           [-1, 64, 32, 32]               0
           Conv2d-30           [-1, 64, 32, 32]          36,864
      BatchNorm2d-31           [-1, 64, 32, 32]             128
             ReLU-32           [-1, 64, 32, 32]               0
           Conv2d-33          [-1, 256, 32, 32]          16,384
      BatchNorm2d-34          [-1, 256, 32, 32]             512
             ReLU-35          [-1, 256, 32, 32]               0
       Bottleneck-36          [-1, 256, 32, 32]               0
           Conv2d-37          [-1, 128, 32, 32]          32,768
      BatchNorm2d-38          [-1, 128, 32, 32]             256
             ReLU-39          [-1, 128, 32, 32]               0
           Conv2d-40          [-1, 128, 16, 16]         147,456
      BatchNorm2d-41          [-1, 128, 16, 16]             256
             ReLU-42          [-1, 128, 16, 16]               0
           Conv2d-43          [-1, 512, 16, 16]          65,536
      BatchNorm2d-44          [-1, 512, 16, 16]           1,024
           Conv2d-45          [-1, 512, 16, 16]         131,072
      BatchNorm2d-46          [-1, 512, 16, 16]           1,024
             ReLU-47          [-1, 512, 16, 16]               0
       Bottleneck-48          [-1, 512, 16, 16]               0
           Conv2d-49          [-1, 128, 16, 16]          65,536
      BatchNorm2d-50          [-1, 128, 16, 16]             256
             ReLU-51          [-1, 128, 16, 16]               0
           Conv2d-52          [-1, 128, 16, 16]         147,456
      BatchNorm2d-53          [-1, 128, 16, 16]             256
             ReLU-54          [-1, 128, 16, 16]               0
           Conv2d-55          [-1, 512, 16, 16]          65,536
      BatchNorm2d-56          [-1, 512, 16, 16]           1,024
             ReLU-57          [-1, 512, 16, 16]               0
       Bottleneck-58          [-1, 512, 16, 16]               0
           Conv2d-59          [-1, 128, 16, 16]          65,536
      BatchNorm2d-60          [-1, 128, 16, 16]             256
             ReLU-61          [-1, 128, 16, 16]               0
           Conv2d-62          [-1, 128, 16, 16]         147,456
      BatchNorm2d-63          [-1, 128, 16, 16]             256
             ReLU-64          [-1, 128, 16, 16]               0
           Conv2d-65          [-1, 512, 16, 16]          65,536
      BatchNorm2d-66          [-1, 512, 16, 16]           1,024
             ReLU-67          [-1, 512, 16, 16]               0
       Bottleneck-68          [-1, 512, 16, 16]               0
           Conv2d-69          [-1, 128, 16, 16]          65,536
      BatchNorm2d-70          [-1, 128, 16, 16]             256
             ReLU-71          [-1, 128, 16, 16]               0
           Conv2d-72          [-1, 128, 16, 16]         147,456
      BatchNorm2d-73          [-1, 128, 16, 16]             256
             ReLU-74          [-1, 128, 16, 16]               0
           Conv2d-75          [-1, 512, 16, 16]          65,536
      BatchNorm2d-76          [-1, 512, 16, 16]           1,024
             ReLU-77          [-1, 512, 16, 16]               0
       Bottleneck-78          [-1, 512, 16, 16]               0
           Conv2d-79          [-1, 256, 16, 16]         131,072
      BatchNorm2d-80          [-1, 256, 16, 16]             512
             ReLU-81          [-1, 256, 16, 16]               0
           Conv2d-82            [-1, 256, 8, 8]         589,824
      BatchNorm2d-83            [-1, 256, 8, 8]             512
             ReLU-84            [-1, 256, 8, 8]               0
           Conv2d-85           [-1, 1024, 8, 8]         262,144
      BatchNorm2d-86           [-1, 1024, 8, 8]           2,048
           Conv2d-87           [-1, 1024, 8, 8]         524,288
      BatchNorm2d-88           [-1, 1024, 8, 8]           2,048
             ReLU-89           [-1, 1024, 8, 8]               0
       Bottleneck-90           [-1, 1024, 8, 8]               0
           Conv2d-91            [-1, 256, 8, 8]         262,144
      BatchNorm2d-92            [-1, 256, 8, 8]             512
             ReLU-93            [-1, 256, 8, 8]               0
           Conv2d-94            [-1, 256, 8, 8]         589,824
      BatchNorm2d-95            [-1, 256, 8, 8]             512
             ReLU-96            [-1, 256, 8, 8]               0
           Conv2d-97           [-1, 1024, 8, 8]         262,144
      BatchNorm2d-98           [-1, 1024, 8, 8]           2,048
             ReLU-99           [-1, 1024, 8, 8]               0
      Bottleneck-100           [-1, 1024, 8, 8]               0
          Conv2d-101            [-1, 256, 8, 8]         262,144
     BatchNorm2d-102            [-1, 256, 8, 8]             512
            ReLU-103            [-1, 256, 8, 8]               0
          Conv2d-104            [-1, 256, 8, 8]         589,824
     BatchNorm2d-105            [-1, 256, 8, 8]             512
            ReLU-106            [-1, 256, 8, 8]               0
          Conv2d-107           [-1, 1024, 8, 8]         262,144
     BatchNorm2d-108           [-1, 1024, 8, 8]           2,048
            ReLU-109           [-1, 1024, 8, 8]               0
      Bottleneck-110           [-1, 1024, 8, 8]               0
          Conv2d-111            [-1, 256, 8, 8]         262,144
     BatchNorm2d-112            [-1, 256, 8, 8]             512
            ReLU-113            [-1, 256, 8, 8]               0
          Conv2d-114            [-1, 256, 8, 8]         589,824
     BatchNorm2d-115            [-1, 256, 8, 8]             512
            ReLU-116            [-1, 256, 8, 8]               0
          Conv2d-117           [-1, 1024, 8, 8]         262,144
     BatchNorm2d-118           [-1, 1024, 8, 8]           2,048
            ReLU-119           [-1, 1024, 8, 8]               0
      Bottleneck-120           [-1, 1024, 8, 8]               0
          Conv2d-121            [-1, 256, 8, 8]         262,144
     BatchNorm2d-122            [-1, 256, 8, 8]             512
            ReLU-123            [-1, 256, 8, 8]               0
          Conv2d-124            [-1, 256, 8, 8]         589,824
     BatchNorm2d-125            [-1, 256, 8, 8]             512
            ReLU-126            [-1, 256, 8, 8]               0
          Conv2d-127           [-1, 1024, 8, 8]         262,144
     BatchNorm2d-128           [-1, 1024, 8, 8]           2,048
            ReLU-129           [-1, 1024, 8, 8]               0
      Bottleneck-130           [-1, 1024, 8, 8]               0
          Conv2d-131            [-1, 256, 8, 8]         262,144
     BatchNorm2d-132            [-1, 256, 8, 8]             512
            ReLU-133            [-1, 256, 8, 8]               0
          Conv2d-134            [-1, 256, 8, 8]         589,824
     BatchNorm2d-135            [-1, 256, 8, 8]             512
            ReLU-136            [-1, 256, 8, 8]               0
          Conv2d-137           [-1, 1024, 8, 8]         262,144
     BatchNorm2d-138           [-1, 1024, 8, 8]           2,048
            ReLU-139           [-1, 1024, 8, 8]               0
      Bottleneck-140           [-1, 1024, 8, 8]               0
          Conv2d-141            [-1, 512, 8, 8]         524,288
     BatchNorm2d-142            [-1, 512, 8, 8]           1,024
            ReLU-143            [-1, 512, 8, 8]               0
          Conv2d-144            [-1, 512, 4, 4]       2,359,296
     BatchNorm2d-145            [-1, 512, 4, 4]           1,024
            ReLU-146            [-1, 512, 4, 4]               0
          Conv2d-147           [-1, 2048, 4, 4]       1,048,576
     BatchNorm2d-148           [-1, 2048, 4, 4]           4,096
          Conv2d-149           [-1, 2048, 4, 4]       2,097,152
     BatchNorm2d-150           [-1, 2048, 4, 4]           4,096
            ReLU-151           [-1, 2048, 4, 4]               0
      Bottleneck-152           [-1, 2048, 4, 4]               0
          Conv2d-153            [-1, 512, 4, 4]       1,048,576
     BatchNorm2d-154            [-1, 512, 4, 4]           1,024
            ReLU-155            [-1, 512, 4, 4]               0
          Conv2d-156            [-1, 512, 4, 4]       2,359,296
     BatchNorm2d-157            [-1, 512, 4, 4]           1,024
            ReLU-158            [-1, 512, 4, 4]               0
          Conv2d-159           [-1, 2048, 4, 4]       1,048,576
     BatchNorm2d-160           [-1, 2048, 4, 4]           4,096
            ReLU-161           [-1, 2048, 4, 4]               0
      Bottleneck-162           [-1, 2048, 4, 4]               0
          Conv2d-163            [-1, 512, 4, 4]       1,048,576
     BatchNorm2d-164            [-1, 512, 4, 4]           1,024
            ReLU-165            [-1, 512, 4, 4]               0
          Conv2d-166            [-1, 512, 4, 4]       2,359,296
     BatchNorm2d-167            [-1, 512, 4, 4]           1,024
            ReLU-168            [-1, 512, 4, 4]               0
          Conv2d-169           [-1, 2048, 4, 4]       1,048,576
     BatchNorm2d-170           [-1, 2048, 4, 4]           4,096
            ReLU-171           [-1, 2048, 4, 4]               0
      Bottleneck-172           [-1, 2048, 4, 4]               0
AdaptiveAvgPool2d-173           [-1, 2048, 1, 1]               0
          Linear-174                 [-1, 1000]       2,049,000
================================================================
Total params: 25,557,032
Trainable params: 25,557,032
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.19
Forward/backward pass size (MB): 93.59
Params size (MB): 97.49
Estimated Total Size (MB): 191.27
----------------------------------------------------------------

191 mb 정도의 가벼운 모델 구조를 확인할 수 있습니다.

다음은 모델 학습을 위한 함수를 정의하겠습니다. 가장 간단한 형태의 train 컨테이너를 구성하였습니다.

필요한 패키지를 임포트합니다.

import torch
import torch.nn as nn
from torch import optim

전이학습을 위해 필요한 파라미터를 정의합니다. lr 은 learning rate 로 0.0001 로 설정하였습니다. 많은 epoch 의 학습을 진행할 때에는 스케쥴러를 사용하면 용이하나, 이번 노트북에서는 가장 간단한 형태의 train 컨테이너를 구성하여 다루지 않았습니다.

optimizer 은 Adam 으로, loss function 은 cross entropy 로 정의합니다. loss function 은 학습 목적에 따라 다양하게 구성할 수 있으므로 참고바랍니다.

학습 횟수인 num_epochs 은 5로 설정하였습니다.

lr = 0.0001
num_epochs = 5
optimizer = optim.Adam(model.parameters(), lr=lr)
loss_function = nn.CrossEntropyLoss().to(device)

이때 설정한 파라미터들을 (수정 예정)

params = {
    'num_epochs':num_epochs,
    'optimizer':optimizer,
    'loss_function':loss_function,
    'train_dataloader':train_dataloader,
    'test_dataloader': test_dataloader,
    'device':device
}
def train(model, params):
    loss_function=params["loss_function"]
    train_dataloader=params["train_dataloader"]
    test_dataloader=params["test_dataloader"]
    device=params["device"]

    for epoch in range(0, num_epochs):
      for i, data in enumerate(train_dataloader, 0):
        # train dataloader 로 불러온 데이터에서 이미지와 라벨을 분리
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)

        # 이전 batch에서 계산된 가중치를 초기화
        optimizer.zero_grad() 

        # forward + back propagation 연산
        outputs = model(inputs)
        train_loss = loss_function(outputs, labels)
        train_loss.backward()
        optimizer.step()

      # test accuracy 계산
      total = 0
      correct = 0
      accuracy = []
      for i, data in enumerate(test_dataloader, 0):
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)

        # 결과값 연산
        outputs = model(inputs)

        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        test_loss = loss_function(outputs, labels).item()
        accuracy.append(100 * correct/total)

      # 학습 결과 출력
      print('Epoch: %d/%d, Train loss: %.6f, Test loss: %.6f, Accuracy: %.2f' %(epoch+1, num_epochs, train_loss.item(), test_loss, 100*correct/total))
train(model, params)
Epoch: 0/5, Train loss: 0.061112, Test loss: 0.425641, Accuracy: 90.45
Epoch: 1/5, Train loss: 0.050574, Test loss: 0.275092, Accuracy: 91.41
Epoch: 2/5, Train loss: 0.042732, Test loss: 0.229009, Accuracy: 91.83
Epoch: 3/5, Train loss: 1.185855, Test loss: 0.299690, Accuracy: 90.94
Epoch: 4/5, Train loss: 0.879423, Test loss: 0.334140, Accuracy: 90.74