PyTorch 로 Transfer-Learning 하기¶
이전 챕터에서 pytorch 로 resnet 구현과 관련한 내용을 다루었습니다. 이번 노트북에서는 pytorch 로 resnet 모델을 학습하는 방법에 대해 살펴보겠습니다.
담당자: 권지현 님
최종수정일: 21-09-29
본 자료는 가짜연구소 3기 Pytorch guide 크루 활동으로 작성됨
01 data load¶
본 노트북에서는 torchvision
에서 제공하는 데이터 셋을 활용합니다. torchvision
에 대한 설명은 링크 를 참조바랍니다.
데이터셋을 활용하기 위한 라이브러리를 import 하겠습니다.
# torchvision 관련 라이브러리 import
from torchvision import utils
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
사용할 데이터 셋은 STL10
입니다. STL10
은 Image Classification 의 벤치마크로 10개의 라벨을 가진 데이터 셋 입니다. torchvisvion
에서는 5000개의 train 데이터와 8000개의 test 로 구성되어 있으며, datasets.STL10
매소드로 다운받을 수 있습니다.
경로를 설정한 후 train, test 데이터를 다운받습니다. 경로는 단순히 root 에 폴더를 생성하여 지정하였습니다.
transforms 은 ToTensor()
로 설정합니다. transforms 에 대한 설명은 링크 를 참조 바랍니다.
import os
os.mkdir('./train')
os.mkdir('./test')
train_dataset = datasets.STL10('/train', split='train', download=True, transform=transforms.ToTensor())
test_dataset = datasets.STL10('/test', split='test', download=True, transform=transforms.ToTensor())
Files already downloaded and verified
Downloading http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz to /test/stl10_binary.tar.gz
Extracting /test/stl10_binary.tar.gz to /test
다운받은 이미지에 대해 스케일링 과정이 필요합니다.
transform
을 활용하여 이미지 크기를 고정하고, normalization
을 진행합니다.
주어진 데이터셋의 이미지는 RGB 3개의 채널로 구성되어 있으므로, 우선 채널 별 mean 값과 std 값을 계산한 후 transform 을 정의합니다.
import numpy as np
# 채널 별 mean 계산
def get_mean(dataset):
meanRGB = [np.mean(image.numpy(), axis=(1,2)) for image,_ in dataset]
meanR = np.mean([m[0] for m in meanRGB])
meanG = np.mean([m[1] for m in meanRGB])
meanB = np.mean([m[2] for m in meanRGB])
return [meanR, meanG, meanB]
# 채널 별 str 계산
def get_std(dataset):
stdRGB = [np.std(image.numpy(), axis=(1,2)) for image,_ in dataset]
stdR = np.mean([s[0] for s in stdRGB])
stdG = np.mean([s[1] for s in stdRGB])
stdB = np.mean([s[2] for s in stdRGB])
return [stdR, stdG, stdB]
transforms.Compose 매소드로 trainsform 단계를 묶어서 진행할 수 있습니다. 본 노트북에서는 이미지의 크기를 임의로 128로 고정한 후, 정규화하는 과정만 진행하겠습니다. 보다 다양한 augmentation 방법은 링크 를 참조바랍니다.
train_transforms = transforms.Compose([transforms.Resize((128, 128)),
transforms.ToTensor(),
transforms.Normalize(get_mean(train_dataset), get_std(train_dataset))])
test_transforms = transforms.Compose([transforms.Resize((128, 128)),
transforms.ToTensor(),
transforms.Normalize(get_mean(test_dataset), get_std(test_dataset))])
# trainsform 정의
train_dataset.transform = train_transforms
test_dataset.transform = test_transforms
이후 dataloader 를 정의합니다. batch size 는 임의로 128, 64로 설정해두었습니다.
# dataloader 정의
train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False)
데이터 사용을 위한 준비가 완료되었습니다.
02 training¶
모델 학습을 위해 pretrained 된 resnet 50 모델을 사용하겠습니다. 해당 resnet 모델은 사전 학습된 모델로, 이미지 분류 문제를 해결할 수 있도록 규모가 큰 데이터(ImageNet)로 미리 학습된 모델을 의미합니다.
torchvision model 에서 구현된 resnet의 구조는 이전 챕터에서 다루었습니다. 관련 내용은 링크 를 참조 바랍니다.
from torchvision import models
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 학습 환경 설정
model = models.resnet50(pretrained=True).to(device) # true 옵션으로 사전 학습된 모델을 로드
Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
torchsummary의 summary 매소드로 모델을 요약하여 확인할 수 있습니다. 사용할 데이터 셋은 128*128 크기의 RGB 3개의 chennel로 구성되어 있습니다. 모델의 layer 별 파라미터 개수는 다음과 같습니다.
from torchsummary import summary
summary(model, (3, 128, 128))
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 64, 64, 64] 9,408
BatchNorm2d-2 [-1, 64, 64, 64] 128
ReLU-3 [-1, 64, 64, 64] 0
MaxPool2d-4 [-1, 64, 32, 32] 0
Conv2d-5 [-1, 64, 32, 32] 4,096
BatchNorm2d-6 [-1, 64, 32, 32] 128
ReLU-7 [-1, 64, 32, 32] 0
Conv2d-8 [-1, 64, 32, 32] 36,864
BatchNorm2d-9 [-1, 64, 32, 32] 128
ReLU-10 [-1, 64, 32, 32] 0
Conv2d-11 [-1, 256, 32, 32] 16,384
BatchNorm2d-12 [-1, 256, 32, 32] 512
Conv2d-13 [-1, 256, 32, 32] 16,384
BatchNorm2d-14 [-1, 256, 32, 32] 512
ReLU-15 [-1, 256, 32, 32] 0
Bottleneck-16 [-1, 256, 32, 32] 0
Conv2d-17 [-1, 64, 32, 32] 16,384
BatchNorm2d-18 [-1, 64, 32, 32] 128
ReLU-19 [-1, 64, 32, 32] 0
Conv2d-20 [-1, 64, 32, 32] 36,864
BatchNorm2d-21 [-1, 64, 32, 32] 128
ReLU-22 [-1, 64, 32, 32] 0
Conv2d-23 [-1, 256, 32, 32] 16,384
BatchNorm2d-24 [-1, 256, 32, 32] 512
ReLU-25 [-1, 256, 32, 32] 0
Bottleneck-26 [-1, 256, 32, 32] 0
Conv2d-27 [-1, 64, 32, 32] 16,384
BatchNorm2d-28 [-1, 64, 32, 32] 128
ReLU-29 [-1, 64, 32, 32] 0
Conv2d-30 [-1, 64, 32, 32] 36,864
BatchNorm2d-31 [-1, 64, 32, 32] 128
ReLU-32 [-1, 64, 32, 32] 0
Conv2d-33 [-1, 256, 32, 32] 16,384
BatchNorm2d-34 [-1, 256, 32, 32] 512
ReLU-35 [-1, 256, 32, 32] 0
Bottleneck-36 [-1, 256, 32, 32] 0
Conv2d-37 [-1, 128, 32, 32] 32,768
BatchNorm2d-38 [-1, 128, 32, 32] 256
ReLU-39 [-1, 128, 32, 32] 0
Conv2d-40 [-1, 128, 16, 16] 147,456
BatchNorm2d-41 [-1, 128, 16, 16] 256
ReLU-42 [-1, 128, 16, 16] 0
Conv2d-43 [-1, 512, 16, 16] 65,536
BatchNorm2d-44 [-1, 512, 16, 16] 1,024
Conv2d-45 [-1, 512, 16, 16] 131,072
BatchNorm2d-46 [-1, 512, 16, 16] 1,024
ReLU-47 [-1, 512, 16, 16] 0
Bottleneck-48 [-1, 512, 16, 16] 0
Conv2d-49 [-1, 128, 16, 16] 65,536
BatchNorm2d-50 [-1, 128, 16, 16] 256
ReLU-51 [-1, 128, 16, 16] 0
Conv2d-52 [-1, 128, 16, 16] 147,456
BatchNorm2d-53 [-1, 128, 16, 16] 256
ReLU-54 [-1, 128, 16, 16] 0
Conv2d-55 [-1, 512, 16, 16] 65,536
BatchNorm2d-56 [-1, 512, 16, 16] 1,024
ReLU-57 [-1, 512, 16, 16] 0
Bottleneck-58 [-1, 512, 16, 16] 0
Conv2d-59 [-1, 128, 16, 16] 65,536
BatchNorm2d-60 [-1, 128, 16, 16] 256
ReLU-61 [-1, 128, 16, 16] 0
Conv2d-62 [-1, 128, 16, 16] 147,456
BatchNorm2d-63 [-1, 128, 16, 16] 256
ReLU-64 [-1, 128, 16, 16] 0
Conv2d-65 [-1, 512, 16, 16] 65,536
BatchNorm2d-66 [-1, 512, 16, 16] 1,024
ReLU-67 [-1, 512, 16, 16] 0
Bottleneck-68 [-1, 512, 16, 16] 0
Conv2d-69 [-1, 128, 16, 16] 65,536
BatchNorm2d-70 [-1, 128, 16, 16] 256
ReLU-71 [-1, 128, 16, 16] 0
Conv2d-72 [-1, 128, 16, 16] 147,456
BatchNorm2d-73 [-1, 128, 16, 16] 256
ReLU-74 [-1, 128, 16, 16] 0
Conv2d-75 [-1, 512, 16, 16] 65,536
BatchNorm2d-76 [-1, 512, 16, 16] 1,024
ReLU-77 [-1, 512, 16, 16] 0
Bottleneck-78 [-1, 512, 16, 16] 0
Conv2d-79 [-1, 256, 16, 16] 131,072
BatchNorm2d-80 [-1, 256, 16, 16] 512
ReLU-81 [-1, 256, 16, 16] 0
Conv2d-82 [-1, 256, 8, 8] 589,824
BatchNorm2d-83 [-1, 256, 8, 8] 512
ReLU-84 [-1, 256, 8, 8] 0
Conv2d-85 [-1, 1024, 8, 8] 262,144
BatchNorm2d-86 [-1, 1024, 8, 8] 2,048
Conv2d-87 [-1, 1024, 8, 8] 524,288
BatchNorm2d-88 [-1, 1024, 8, 8] 2,048
ReLU-89 [-1, 1024, 8, 8] 0
Bottleneck-90 [-1, 1024, 8, 8] 0
Conv2d-91 [-1, 256, 8, 8] 262,144
BatchNorm2d-92 [-1, 256, 8, 8] 512
ReLU-93 [-1, 256, 8, 8] 0
Conv2d-94 [-1, 256, 8, 8] 589,824
BatchNorm2d-95 [-1, 256, 8, 8] 512
ReLU-96 [-1, 256, 8, 8] 0
Conv2d-97 [-1, 1024, 8, 8] 262,144
BatchNorm2d-98 [-1, 1024, 8, 8] 2,048
ReLU-99 [-1, 1024, 8, 8] 0
Bottleneck-100 [-1, 1024, 8, 8] 0
Conv2d-101 [-1, 256, 8, 8] 262,144
BatchNorm2d-102 [-1, 256, 8, 8] 512
ReLU-103 [-1, 256, 8, 8] 0
Conv2d-104 [-1, 256, 8, 8] 589,824
BatchNorm2d-105 [-1, 256, 8, 8] 512
ReLU-106 [-1, 256, 8, 8] 0
Conv2d-107 [-1, 1024, 8, 8] 262,144
BatchNorm2d-108 [-1, 1024, 8, 8] 2,048
ReLU-109 [-1, 1024, 8, 8] 0
Bottleneck-110 [-1, 1024, 8, 8] 0
Conv2d-111 [-1, 256, 8, 8] 262,144
BatchNorm2d-112 [-1, 256, 8, 8] 512
ReLU-113 [-1, 256, 8, 8] 0
Conv2d-114 [-1, 256, 8, 8] 589,824
BatchNorm2d-115 [-1, 256, 8, 8] 512
ReLU-116 [-1, 256, 8, 8] 0
Conv2d-117 [-1, 1024, 8, 8] 262,144
BatchNorm2d-118 [-1, 1024, 8, 8] 2,048
ReLU-119 [-1, 1024, 8, 8] 0
Bottleneck-120 [-1, 1024, 8, 8] 0
Conv2d-121 [-1, 256, 8, 8] 262,144
BatchNorm2d-122 [-1, 256, 8, 8] 512
ReLU-123 [-1, 256, 8, 8] 0
Conv2d-124 [-1, 256, 8, 8] 589,824
BatchNorm2d-125 [-1, 256, 8, 8] 512
ReLU-126 [-1, 256, 8, 8] 0
Conv2d-127 [-1, 1024, 8, 8] 262,144
BatchNorm2d-128 [-1, 1024, 8, 8] 2,048
ReLU-129 [-1, 1024, 8, 8] 0
Bottleneck-130 [-1, 1024, 8, 8] 0
Conv2d-131 [-1, 256, 8, 8] 262,144
BatchNorm2d-132 [-1, 256, 8, 8] 512
ReLU-133 [-1, 256, 8, 8] 0
Conv2d-134 [-1, 256, 8, 8] 589,824
BatchNorm2d-135 [-1, 256, 8, 8] 512
ReLU-136 [-1, 256, 8, 8] 0
Conv2d-137 [-1, 1024, 8, 8] 262,144
BatchNorm2d-138 [-1, 1024, 8, 8] 2,048
ReLU-139 [-1, 1024, 8, 8] 0
Bottleneck-140 [-1, 1024, 8, 8] 0
Conv2d-141 [-1, 512, 8, 8] 524,288
BatchNorm2d-142 [-1, 512, 8, 8] 1,024
ReLU-143 [-1, 512, 8, 8] 0
Conv2d-144 [-1, 512, 4, 4] 2,359,296
BatchNorm2d-145 [-1, 512, 4, 4] 1,024
ReLU-146 [-1, 512, 4, 4] 0
Conv2d-147 [-1, 2048, 4, 4] 1,048,576
BatchNorm2d-148 [-1, 2048, 4, 4] 4,096
Conv2d-149 [-1, 2048, 4, 4] 2,097,152
BatchNorm2d-150 [-1, 2048, 4, 4] 4,096
ReLU-151 [-1, 2048, 4, 4] 0
Bottleneck-152 [-1, 2048, 4, 4] 0
Conv2d-153 [-1, 512, 4, 4] 1,048,576
BatchNorm2d-154 [-1, 512, 4, 4] 1,024
ReLU-155 [-1, 512, 4, 4] 0
Conv2d-156 [-1, 512, 4, 4] 2,359,296
BatchNorm2d-157 [-1, 512, 4, 4] 1,024
ReLU-158 [-1, 512, 4, 4] 0
Conv2d-159 [-1, 2048, 4, 4] 1,048,576
BatchNorm2d-160 [-1, 2048, 4, 4] 4,096
ReLU-161 [-1, 2048, 4, 4] 0
Bottleneck-162 [-1, 2048, 4, 4] 0
Conv2d-163 [-1, 512, 4, 4] 1,048,576
BatchNorm2d-164 [-1, 512, 4, 4] 1,024
ReLU-165 [-1, 512, 4, 4] 0
Conv2d-166 [-1, 512, 4, 4] 2,359,296
BatchNorm2d-167 [-1, 512, 4, 4] 1,024
ReLU-168 [-1, 512, 4, 4] 0
Conv2d-169 [-1, 2048, 4, 4] 1,048,576
BatchNorm2d-170 [-1, 2048, 4, 4] 4,096
ReLU-171 [-1, 2048, 4, 4] 0
Bottleneck-172 [-1, 2048, 4, 4] 0
AdaptiveAvgPool2d-173 [-1, 2048, 1, 1] 0
Linear-174 [-1, 1000] 2,049,000
================================================================
Total params: 25,557,032
Trainable params: 25,557,032
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.19
Forward/backward pass size (MB): 93.59
Params size (MB): 97.49
Estimated Total Size (MB): 191.27
----------------------------------------------------------------
191 mb 정도의 가벼운 모델 구조를 확인할 수 있습니다.
다음은 모델 학습을 위한 함수를 정의하겠습니다. 가장 간단한 형태의 train 컨테이너를 구성하였습니다.
필요한 패키지를 임포트합니다.
import torch
import torch.nn as nn
from torch import optim
전이학습을 위해 필요한 파라미터를 정의합니다.
lr
은 learning rate 로 0.0001 로 설정하였습니다. 많은 epoch 의 학습을 진행할 때에는 스케쥴러를 사용하면 용이하나, 이번 노트북에서는 가장 간단한 형태의 train 컨테이너를 구성하여 다루지 않았습니다.
optimizer
은 Adam 으로, loss function
은 cross entropy 로 정의합니다. loss function 은 학습 목적에 따라 다양하게 구성할 수 있으므로 참고바랍니다.
학습 횟수인 num_epochs
은 5로 설정하였습니다.
lr = 0.0001
num_epochs = 5
optimizer = optim.Adam(model.parameters(), lr=lr)
loss_function = nn.CrossEntropyLoss().to(device)
이때 설정한 파라미터들을 (수정 예정)
params = {
'num_epochs':num_epochs,
'optimizer':optimizer,
'loss_function':loss_function,
'train_dataloader':train_dataloader,
'test_dataloader': test_dataloader,
'device':device
}
def train(model, params):ㄴ
loss_function=params["loss_function"]
train_dataloader=params["train_dataloader"]
test_dataloader=params["test_dataloader"]
device=params["device"]
for epoch in range(0, num_epochs):
for i, data in enumerate(train_dataloader, 0):
# train dataloader 로 불러온 데이터에서 이미지와 라벨을 분리
inputs, labels = data
inputs = inputs.to(device)
labels = labels.to(device)
# 이전 batch에서 계산된 가중치를 초기화
optimizer.zero_grad()
# forward + back propagation 연산
outputs = model(inputs)
train_loss = loss_function(outputs, labels)
train_loss.backward()
optimizer.step()
# test accuracy 계산
total = 0
correct = 0
accuracy = []
for i, data in enumerate(test_dataloader, 0):
inputs, labels = data
inputs = inputs.to(device)
labels = labels.to(device)
# 결과값 연산
outputs = model(inputs)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
test_loss = loss_function(outputs, labels).item()
accuracy.append(100 * correct/total)
# 학습 결과 출력
print('Epoch: %d/%d, Train loss: %.6f, Test loss: %.6f, Accuracy: %.2f' %(epoch+1, num_epochs, train_loss.item(), test_loss, 100*correct/total))
train(model, params)
Epoch: 0/5, Train loss: 0.061112, Test loss: 0.425641, Accuracy: 90.45
Epoch: 1/5, Train loss: 0.050574, Test loss: 0.275092, Accuracy: 91.41
Epoch: 2/5, Train loss: 0.042732, Test loss: 0.229009, Accuracy: 91.83
Epoch: 3/5, Train loss: 1.185855, Test loss: 0.299690, Accuracy: 90.94
Epoch: 4/5, Train loss: 0.879423, Test loss: 0.334140, Accuracy: 90.74