3. GAN
2์ฅ์์๋ Victorain400 ๋ฐ์ดํฐ์ ์ ์ด๋ฏธ์ง๋ฅผ ํด๋๋ณ๋ก ์๊ฐํ ํด๋ณด์์ต๋๋ค. 3์ฅ์์๋ ๋ณธ๊ฒฉ์ ์ผ๋ก GAN์ ์ด์ฉํ์ฌ ํ๋ฐฑ ์ด๋ฏธ์ง๋ฅผ ์ปฌ๋ฌ ์ด๋ฏธ์ง๋ก ๋ณํ์ํค๋ ๋ชจ๋ธ์ ๊ตฌ์ถํด๋ณด๋๋ก ํ๊ฒ ์ต๋๋ค.
3.1절에서는 모델링에 사용할 데이터를 불러오고, 3.2절에서는 데이터셋 클래스를 정의하여 데이터 로더까지 정의하도록 하겠습니다. 3.3절에서는 Generator 클래스와 Discriminator 클래스를 정의하여 모델을 구축해보겠습니다. 그리고 3.4절에서는 하이퍼파라미터를 설정하여 구축된 모델을 학습시킨 후, 3.5절에서 테스트 데이터에 대한 예측과 성능평가를 해보도록 하겠습니다.
3.1 데이터셋 다운로드
2.1์ ์์ ๋์จ ์ฝ๋๋ฅผ ํ์ฉํ์ฌ ๋ชจ๋ธ์ ์ฌ์ฉํ ๋ฐ์ดํฐ์
์ ๋ด๋ ค๋ฐ๋๋ก ํ๊ฒ ์ต๋๋ค. ๊ฐ์ง์ฐ๊ตฌ์ ๊นํ๋ธ์ Tutorial-Book-Utils๋ฅผ clone
ํ๊ณ PL_data_loader.py ํ์ผ๋ก Victorian400 ๋ฐ์ดํฐ์
์ ๋ด๋ ค๋ฐ์ ์์ถ์ ํธ๋ ์์์
๋๋ค.
!git clone https://github.com/Pseudo-Lab/Tutorial-Book-Utils
!python Tutorial-Book-Utils/PL_data_loader.py --data GAN-Colorization
!unzip -q Victorian400-GAN-colorization-data.zip
์ด๋ฒ์๋ 3์ฅ์์ ์ฌ์ฉํ ํจํค์ง๋ค์ ๋ถ๋ฌ์ค๊ฒ ์ต๋๋ค. os
์ glob
๋ ํ์ผ ๊ฒฝ๋ก๋ฅผ ์ง์ ํ ์ ์๋ ํจํค์ง์ด๋ฉฐ, datetime
์ ๋ ์ง, ์๊ฐ์ ๊ณ์ฐํ ์ ์๋ ํจํค์ง์
๋๋ค. ๊ทธ๋ฆฌ๊ณ numpy
๋ ์์น ์ฐ์ฐ์ ์ฌ์ฉ๋๋ ์ ํ ๋์ ํจํค์ง์ด๋ฉฐ, matplotlib
๊ณผ PIL
์ Image
๋ ์ด๋ฏธ์ง ํ์ผ์ ์๊ฐํํ๋๋ฐ ์ฌ์ฉํ๋ ํจํค์ง์
๋๋ค. ๊ทธ ์ธ์ torch
์ torchvision
ํจํค์ง๋ ๋ฐ์ดํฐ์
์ ์์ ๋ชจ๋ธ ๊ตฌ์ถ์ ์ฌ์ฉ๋๋ Torch ๊ธฐ๋ฐ์ ํจํค์ง๋ค์
๋๋ค.
import os
import glob
import datetime
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
3.2 데이터셋 클래스 정의
파이토치 모델에 데이터를 학습하기 위해서 데이터셋 클래스를 정의해보도록 하겠습니다. 흑백, 컬러 이미지의 폴더 경로와 이미지 변환(transforms)을 인자로 받아 Image.open 함수를 사용하여 해당 이미지를 불러옵니다. 흑백 이미지는 .convert("L")을 사용하여 단일 채널로, 컬러 이미지는 .convert("RGB")을 사용하여 3채널로 받아줍니다. 그리고 이미지 변환을 거쳐 딕셔너리 형태로 흑백 이미지와 컬러 이미지를 “A”, “B”에 각각 반환시켜줍니다.
class VictorianDataset(Dataset):
    """Paired grayscale/colour image dataset read from two sibling folders.

    Grayscale inputs are listed from ``<root>/gray`` and colour targets from
    ``<root>/resized``. Both listings are sorted, and item ``index`` pairs the
    entries found at that position (wrapped modulo each listing's length).

    Args:
        root: Directory that contains the ``gray`` and ``resized`` folders.
        color_transforms_: Sequence of transforms applied in order to the RGB
            image. ``None`` means "apply nothing".
        gray_transforms_: Sequence of transforms applied in order to the
            single-channel image. ``None`` means "apply nothing".
    """

    def __init__(self, root, color_transforms_=None, gray_transforms_=None):
        # A Compose built from None cannot be applied to an image (None is
        # not iterable), so map the default to an empty pipeline instead.
        if color_transforms_ is None:
            color_transforms_ = []
        if gray_transforms_ is None:
            gray_transforms_ = []
        self.color_transforms = transforms.Compose(color_transforms_)
        self.gray_transforms = transforms.Compose(gray_transforms_)

        self.gray_files = sorted(glob.glob(os.path.join(root, 'gray') + "/*.*"))
        self.color_files = sorted(glob.glob(os.path.join(root, 'resized') + "/*.*"))

    def __getitem__(self, index):
        """Return ``{"A": gray image, "B": colour image}`` for ``index``."""
        gray_path = self.gray_files[index % len(self.gray_files)]
        color_path = self.color_files[index % len(self.color_files)]

        # convert() returns a new, fully loaded image, so the file handle
        # opened by Image.open can be released right away.
        with Image.open(gray_path) as gray_file:
            gray_img = gray_file.convert("L")
        with Image.open(color_path) as color_file:
            color_img = color_file.convert("RGB")

        gray_img = self.gray_transforms(gray_img)
        color_img = self.color_transforms(color_img)

        return {"A": gray_img, "B": color_img}

    def __len__(self):
        """Return the number of grayscale files found under ``<root>/gray``."""
        return len(self.gray_files)
이미지 경로와 사이즈, 각 데이터셋의 배치사이즈를 지정해줍니다. 이미지 사이즈와 배치사이즈는 각자의 환경(최대 메모리 사용량)에 맞춰 설정해주면 됩니다. 단, 이후에 나올 모델 구조가 256으로 고정되어 있어 다른 이미지 사이즈를 원하실 경우 모델 구조도 약간의 수정이 필요합니다.
# Dataset locations: `root` is handed to the training dataset, `test_root`
# to the test dataset.
root = ''
test_root = '{}test/'.format(root)

# Target size passed to the Resize transform of both image pipelines.
img_height = img_width = 256

# Mini-batch sizes of the training loader and the test loader.
batch_size, test_batch_size = 12, 6

# Device index passed to `.cuda(gpu)`.
gpu = 0
๋ฐ์ดํฐ์
ํด๋์ค์ ์ธ์๋ก ๋ฃ์ ์ด๋ฏธ์ง ๋ณํ(transform)์ ์ง์ ํฉ๋๋ค. ํํ ๋ฆฌ์ผ์์๋ ํ์ดํ ์น ๋ชจ๋ธ์ ๋ฃ๊ธฐ ์ํด tensor ํ์
์ผ๋ก ๋ฐ๊ฟ์ฃผ๊ณ 2.4์ ์์ ๊ตฌํ ํ๊ท ๊ณผ ํ์คํธ์ฐจ๋ก normalize
๋ฅผ ํด์ฃผ๋๋ก ํ๊ฒ ์ต๋๋ค. ์ด ์ธ์๋ ๋ชจ๋ธ ๊ตฌ์กฐ์ ๋ฐ๋ผ Resize
๋ฅผ ํด์ฃผ๊ฑฐ๋, ๋๋ฉ์ธ์ ๋ฐ๋ผ RandomCrop
(๋๋ค์ผ๋ก ์๋ฅด๊ธฐ), RandomVerticalFlip
(๋๋ค์ผ๋ก ์ํ ๋ค์ง๊ธฐ) ๋ฑ ๋ค์ํ ์ด๋ฏธ์ง ๋ณํ์ ํ ์ ์์ต๋๋ค.
# Per-channel statistics used for normalisation: three values for the RGB
# images, a single value for the one-channel grayscale images.
color_mean = [0.58090717, 0.52688643, 0.45678478]
color_std = [0.25644188, 0.25482641, 0.24456465]
gray_mean = [0.5350533]
gray_std = [0.25051587]

# Both pipelines run the same three steps: resize to the configured size,
# convert to a tensor, then normalise with the matching statistics.
color_transforms_ = [
    transforms.Resize((img_height, img_width)),
    transforms.ToTensor(),
    transforms.Normalize(color_mean, color_std),
]

gray_transforms_ = [
    transforms.Resize((img_height, img_width)),
    transforms.ToTensor(),
    transforms.Normalize(gray_mean, gray_std),
]
앞에서 정의한 데이터셋 클래스와 이미지 변환(transform)을 DataLoader 함수에 넣어줍니다.
# Training batches of `batch_size` gray/colour pairs, reshuffled by the loader.
train_loader = DataLoader(
    dataset=VictorianDataset(
        root,
        gray_transforms_=gray_transforms_,
        color_transforms_=color_transforms_,
    ),
    shuffle=True,
    batch_size=batch_size,
)
๋ฐ์ดํฐ ๋ก๋๊ฐ ์ ๋๋ก ๊ตฌ์ฑ๋์ด ์๋์ง ์์๋ณด๊ธฐ ์ํด ์๊ฐํํด๋ณด๊ฒ ์ต๋๋ค. ๋ฐ์ดํฐ ๋ก๋์๋ ์ ๊ทํ๋ ์ด๋ฏธ์ง๊ฐ ์ ์ฅ๋์ด ์๊ธฐ ๋๋ฌธ์ ์ ๊ทํ๋ฅผ ๋ณต์์ํค๊ณ ์๊ฐํํด์ฃผ์ด์ผ ํฉ๋๋ค. ๋ฐ๋ผ์ reNormalize
ํจ์๋ฅผ ๋ง๋ค์ด ๋ค์ ํ์คํธ์ฐจ๋ฅผ ๊ณฑํ๊ณ ํ๊ท ์ ๋ํด์ค๋๋ค. ์ด๋ .transpose()
๋ ์ถ์ ์์๋ฅผ ๋ฐ๊ฟ์ฃผ๊ณ .clip(min, max)
์ min๋ณด๋ค ์์ผ๋ฉด min์ผ๋ก, max๋ณด๋ค ํฌ๋ฉด max๋ก ๋ฐ๊ฟ์ฃผ๋ ์ญํ ์ ํฉ๋๋ค. ๋ํ, โAโ์ ํด๋นํ๋ ํ๋ฐฑ ์ด๋ฏธ์ง๋ ๋จ์ผ ์ฑ๋์ด๊ธฐ ๋๋ฌธ์ .reshape()
์ ํตํด 2์ฐจ์์ผ๋ก ๋ฐ๊ฟ์ฃผ๊ณ cmap=gray
๋ฅผ ์ค์ ํด์ฃผ์ด์ผ ์ด๋ฏธ์ง๋ฅผ ์ ๋๋ก ์ถ๋ ฅํ ์ ์์ต๋๋ค.
def reNormalize(img, mean, std):
    """Undo channel-wise normalisation and return a displayable image array.

    Args:
        img: Tensor laid out channel-first, ``(C, H, W)``.
        mean: Per-channel means that were subtracted during normalisation.
        std: Per-channel standard deviations that were divided out.

    Returns:
        NumPy array of shape ``(H, W, C)`` with values clipped to ``[0, 1]``.
    """
    # detach()/cpu() let this accept tensors that still require grad or live
    # on a GPU; for a plain CPU tensor they change nothing.
    img = img.detach().cpu().numpy().transpose(1, 2, 0)
    # mean/std broadcast over the trailing channel axis.
    img = img * std + mean
    return img.clip(0, 1)
# Sanity check of the training loader: print the tensor shapes of its first
# batch and draw that batch's first gray/colour pair side by side.
fig = plt.figure(figsize=(10, 5))
rows, cols = 1, 2

X = next(iter(train_loader), None)  # None when the loader yields no batch
if X is not None:
    print(X['A'].shape, X['B'].shape)

    # Undo the normalisation first; the gray image is then reshaped to 2-D
    # so that imshow can draw it with the gray colour map.
    gray_preview = reNormalize(X["A"][0], gray_mean, gray_std)
    color_preview = reNormalize(X["B"][0], color_mean, color_std)

    ax1 = fig.add_subplot(rows, cols, 1)
    ax1.imshow(gray_preview.reshape(img_height, img_width), cmap='gray')
    ax1.set_title('gray img')

    ax2 = fig.add_subplot(rows, cols, 2)
    ax2.imshow(color_preview)
    ax2.set_title('color img')

    plt.show()
torch.Size([12, 1, 256, 256]) torch.Size([12, 3, 256, 256])
데이터 로더가 제대로 구성되어 흑백 이미지와 컬러 이미지가 잘 출력되고 있습니다. 마찬가지로 테스트 데이터 로더도 정의하고 시각화해보겠습니다.
# Test batches of `test_batch_size` gray/colour pairs, kept in file order.
test_loader = DataLoader(
    dataset=VictorianDataset(
        test_root,
        gray_transforms_=gray_transforms_,
        color_transforms_=color_transforms_,
    ),
    shuffle=False,
    batch_size=test_batch_size,
)
# Same sanity check for the test loader: shapes of the first batch plus its
# first gray/colour pair, drawn next to each other.
fig = plt.figure(figsize=(10, 5))
rows, cols = 1, 2

X = next(iter(test_loader), None)  # None when the loader yields no batch
if X is not None:
    print(X['A'].shape, X['B'].shape)

    # De-normalise before plotting; the gray image is reshaped to 2-D for imshow.
    gray_preview = reNormalize(X["A"][0], gray_mean, gray_std)
    color_preview = reNormalize(X["B"][0], color_mean, color_std)

    ax1 = fig.add_subplot(rows, cols, 1)
    ax1.imshow(gray_preview.reshape(img_height, img_width), cmap='gray')
    ax1.set_title('gray img')

    ax2 = fig.add_subplot(rows, cols, 2)
    ax2.imshow(color_preview)
    ax2.set_title('color img')

    plt.show()
torch.Size([6, 1, 256, 256]) torch.Size([6, 3, 256, 256])
테스트 데이터 로더도 잘 구성된 것을 확인할 수 있습니다.
3.3 모델 구축
이번에는 데이터셋을 학습시킬 모델을 구축해보도록 하겠습니다. 3.3절에서는 GAN 모델을 사용하여 흑백 이미지를 컬러 이미지로 변환시켜볼 것입니다. GAN 모델은 생성 모델(Generator)과 판별 모델(Discriminator)로 이루어져 있습니다. 각각은 Conv2d와 ConvTranspose2d, MaxPool2d, 정규화, 활성화함수들로 이루어져 있습니다.
class Generator(nn.Module):
    """Small encoder-decoder mapping a 2-channel input to a 3-channel image.

    The two input channels are the gray image and a noise image. The network
    downsamples once (max-pool), upsamples once (transposed convolution),
    adds the full-resolution features back as a skip connection and ends in
    ``Tanh``, so every output value lies in ``[-1, 1]``.
    """

    def __init__(self):
        super().__init__()

        def conv_bn_act(in_ch, out_ch, kernel, stride, pad, slope):
            # Conv2d -> BatchNorm2d -> LeakyReLU, the unit used throughout.
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel, stride, pad),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(slope),
            )

        # Full-resolution features; the input is noise image + gray image.
        self.conv1 = conv_bn_act(2, 64, 3, 1, 1, 0.1)
        self.maxpool = nn.MaxPool2d(2, 2)
        # Half-resolution features.
        self.conv2 = conv_bn_act(64, 128, 3, 1, 1, 0.1)
        # Back to full resolution (kernel 4, stride 2, padding 1 doubles H and W).
        self.upsample = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
        )
        self.conv1by1 = conv_bn_act(64, 64, 1, 1, 0, 0.1)
        # Projection to the three output channels.
        self.conv = nn.Sequential(
            nn.Conv2d(64, 3, 3, 1, 1),
            nn.Tanh(),
        )

    def forward(self, input):
        """Return the generated 3-channel image for a ``(N, 2, H, W)`` batch."""
        skip = self.conv1(input)
        bottleneck = self.conv2(self.maxpool(skip))
        # H and W must be even, otherwise the upsampled map and the skip
        # features differ in size and cannot be added.
        merged = self.upsample(bottleneck) + skip
        return self.conv(self.conv1by1(merged))
class Discriminator(nn.Module):
    """Convolutional classifier that scores 3-channel images in ``[0, 1]``.

    Five stride-2 convolutions (plus one 1x1 convolution) reduce the image to
    a 1024-channel feature map, which is flattened and passed through three
    linear layers ending in a sigmoid.

    Args:
        img_size: ``(height, width)`` of the images that will be scored. The
            default matches the 256x256 resolution the flatten size was
            originally hard-coded for (``1024 * 8 * 8``).
        out_features: Width of the last linear layer. ``None`` keeps the
            original behaviour of reading the module-level ``batch_size``.

    Raises:
        ValueError: If ``img_size`` is smaller than 32 in either dimension.
    """

    def __init__(self, img_size=(256, 256), out_features=None):
        super().__init__()

        if out_features is None:
            # NOTE(review): tying the head width to the batch size means each
            # image gets `batch_size` scores instead of one. It is kept as the
            # default so the output shape and saved checkpoints stay
            # compatible; pass out_features=1 for one score per image.
            out_features = batch_size

        # Each stride-2 convolution (kernel 4, padding 1) halves H and W,
        # rounding down; the 1x1 convolution keeps them. Five halvings -> // 32.
        feat_h, feat_w = img_size[0] // 32, img_size[1] // 32
        if feat_h < 1 or feat_w < 1:
            raise ValueError("img_size must be at least 32x32, got %r" % (img_size,))

        self.main = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 128, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1024, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(1024),
            nn.LeakyReLU(0.2, inplace=True),
        )

        self.fc = nn.Sequential(
            nn.Linear(1024 * feat_h * feat_w, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, out_features),
            nn.Sigmoid(),
        )

    def forward(self, input, b_size):
        """Score a batch of images.

        Args:
            input: Image batch of shape ``(b_size, 3, H, W)`` with ``(H, W)``
                equal to the ``img_size`` given at construction.
            b_size: Number of images in ``input``; used to flatten the
                convolutional features to ``(b_size, -1)``.

        Returns:
            Tensor of shape ``(b_size, out_features)`` with values in ``[0, 1]``.
        """
        output = self.main(input)
        output = self.fc(output.view(b_size, -1))
        return output
모델을 학습시키기 전에 가중치를 초기화시켜주어야 합니다. 따라서 weights_init 함수를 만들어 생성 모델(Generator)과 판별 모델(Discriminator)에 붙여줍니다.
def weights_init(m):
    """Re-initialise one module in place; meant to be used with ``Module.apply``.

    Modules whose class name contains ``Conv`` get weights drawn from
    N(0, 0.02). Modules whose class name contains ``BatchNorm`` get weights
    drawn from N(1, 0.02) and a zero bias. Every other module is left as is.
    """
    layer_kind = type(m).__name__
    if 'Conv' in layer_kind:
        # Only the weight is touched; a convolution bias keeps its value.
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
        return
    if 'BatchNorm' in layer_kind:
        nn.init.normal_(m.weight, mean=1.0, std=0.02)
        nn.init.zeros_(m.bias)
๊ทธ๋ฆฌ๊ณ ์์ฑ ๋ชจ๋ธ(Generator)๊ณผ ํ๋ณ ๋ชจ๋ธ(Discriminator)์ GPU์์ ํ์ต์ํค๊ธฐ ์ํด .cuda(gpu)
๋ฅผ ํตํด ๋ชจ๋ธ์ ์ ๋ฌํด์ค๋๋ค. ์ด ๋, gpu
์๋ ์์์ ์ง์ ํ device๊ฐ ๋ค์ด๊ฐ๋๋ค. ์ถ๋ ฅ๋๋ ๊ฒ์ ํตํด ๋ชจ๋ธ ๊ตฌ์กฐ๋ฅผ ํ์ธํ ์ ์์ต๋๋ค.
# Build the generator, move it to GPU `gpu` when CUDA is available (it stays
# on the CPU otherwise), then re-initialise its layers with weights_init.
Gener = Generator()
if torch.cuda.is_available():
    Gener = Gener.cuda(gpu)
Gener.apply(weights_init)
Generator(
(conv1): Sequential(
(0): Conv2d(2, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): LeakyReLU(negative_slope=0.1)
)
(maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(conv2): Sequential(
(0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): LeakyReLU(negative_slope=0.1)
)
(upsample): Sequential(
(0): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): LeakyReLU(negative_slope=0.2)
)
(conv1by1): Sequential(
(0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): LeakyReLU(negative_slope=0.1)
)
(conv): Sequential(
(0): Conv2d(64, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): Tanh()
)
)
# Build the discriminator, move it to GPU `gpu` when CUDA is available (it
# stays on the CPU otherwise), then re-initialise its layers with weights_init.
Discri = Discriminator()
if torch.cuda.is_available():
    Discri = Discri.cuda(gpu)
Discri.apply(weights_init)
Discriminator(
(main): Sequential(
(0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(1): LeakyReLU(negative_slope=0.2, inplace=True)
(2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(4): LeakyReLU(negative_slope=0.2, inplace=True)
(5): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(6): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(7): LeakyReLU(negative_slope=0.2, inplace=True)
(8): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(9): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(10): LeakyReLU(negative_slope=0.2, inplace=True)
(11): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(12): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(13): LeakyReLU(negative_slope=0.2, inplace=True)
(14): Conv2d(512, 1024, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(15): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(16): LeakyReLU(negative_slope=0.2, inplace=True)
)
(fc): Sequential(
(0): Linear(in_features=65536, out_features=1024, bias=True)
(1): LeakyReLU(negative_slope=0.2)
(2): Linear(in_features=1024, out_features=256, bias=True)
(3): LeakyReLU(negative_slope=0.2)
(4): Linear(in_features=256, out_features=12, bias=True)
(5): Sigmoid()
)
)
3.4 모델 학습
이제 구축된 모델을 학습시켜보겠습니다. 학습 횟수는 50으로 지정하고, 생성모델과 판별모델 모두 Adam 최적화 방법을 사용해보겠습니다. 각 하이퍼파라미터는 튜닝하여 다양한 실험을 해볼 수 있습니다.
학습은 생성모델, 판별모델 순으로 진행됩니다. 생성모델에서는 판별모델을 속일 수 있는 컬러 이미지를 생성하도록 학습합니다. 그리고 판별모델에서는 실제 컬러 이미지를 진짜(real)로, 생성모델이 흑백 이미지로부터 만들어낸 이미지를 가짜(fake)로 인식하도록 학습합니다.
# Number of training epochs.
max_epoch = 50

# Both networks are optimised with Adam using the same learning rate and
# the same beta coefficients.
optimizerG = torch.optim.Adam(Gener.parameters(), lr=2e-4, betas=(0.5, 0.999))
optimizerD = torch.optim.Adam(Discri.parameters(), lr=2e-4, betas=(0.5, 0.999))
# Feed the data to whatever device the generator was placed on, so the loop
# also runs when CUDA is unavailable and the models stayed on the CPU.
device = next(Gener.parameters()).device

for epoch in range(max_epoch):
    start_time = datetime.datetime.now()
    for i, data in enumerate(train_loader):
        grays, color = data['A'], data['B']
        b_size = len(grays)

        ######## Train Generator ########
        # Generator input: the gray image stacked with a uniform [0, 1) noise map.
        noise = torch.rand(b_size, 1, img_height, img_width)
        gray_noise = torch.cat([grays, noise], dim=1).to(device)
        fake_img = Gener(gray_noise)
        output = Discri(fake_img, b_size)
        # Least-squares loss: push D's score for generated images towards 1.
        g_loss = torch.mean((output - 1) ** 2)

        ######## Backpropagation & Optimize G ########
        Discri.zero_grad()
        Gener.zero_grad()
        g_loss.backward()
        optimizerG.step()

        ######## Train Discriminator ########
        color = color.to(device)
        noise = torch.rand(b_size, 1, img_height, img_width)
        gray_noise = torch.cat([grays, noise], dim=1).to(device)

        ######## Train D to recognize color image as real ########
        output = Discri(color, b_size)
        real_loss = torch.mean((output - 1) ** 2)

        ######## Train D to recognize generated image as fake ########
        # Only D is updated in this step, so the generator output is detached
        # and no gradients are back-propagated through G.
        fake_img = Gener(gray_noise).detach()
        output = Discri(fake_img, b_size)
        fake_loss = torch.mean(output ** 2)

        ######## Backpropagation & Optimize D ########
        d_loss = real_loss + fake_loss
        # Discards the gradients the generator step left on D.
        Discri.zero_grad()
        Gener.zero_grad()
        d_loss.backward()
        optimizerD.step()

    epoch_time = datetime.datetime.now() - start_time
    if (epoch + 1) % 5 == 0:
        # Checkpoint both networks every 5 epochs.
        torch.save(Gener.state_dict(), "generator_%d.pth" % (epoch+1))
        torch.save(Discri.state_dict(), "discriminator_%d.pth" % (epoch+1))

        # The value printed after "ETA" is the wall-clock duration of this
        # epoch; the losses are those of its last batch.
        print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] ETA: %s" % (epoch+1, max_epoch, i+1, len(train_loader), d_loss.item(), g_loss.item(), epoch_time))

        # Show the images generated for the last batch of the epoch.
        fake_grid = torchvision.utils.make_grid(fake_img)
        plt.imshow(reNormalize(fake_grid.cpu(), color_mean, color_std))
        plt.show()
[Epoch 5/50] [Batch 34/34] [D loss: 0.335707] [G loss: 0.952135] ETA: 0:00:15.254198
[Epoch 10/50] [Batch 34/34] [D loss: 0.232681] [G loss: 0.925357] ETA: 0:00:16.087967
[Epoch 15/50] [Batch 34/34] [D loss: 0.222960] [G loss: 0.742795] ETA: 0:00:14.794427
[Epoch 20/50] [Batch 34/34] [D loss: 0.002736] [G loss: 0.968330] ETA: 0:00:15.417760
[Epoch 25/50] [Batch 34/34] [D loss: 0.000585] [G loss: 0.999990] ETA: 0:00:15.662105
[Epoch 30/50] [Batch 34/34] [D loss: 0.008619] [G loss: 0.997328] ETA: 0:00:15.388837
[Epoch 35/50] [Batch 34/34] [D loss: 0.201541] [G loss: 0.998293] ETA: 0:00:15.455658
[Epoch 40/50] [Batch 34/34] [D loss: 0.011544] [G loss: 0.999389] ETA: 0:00:15.419755
[Epoch 45/50] [Batch 34/34] [D loss: 0.000001] [G loss: 0.999999] ETA: 0:00:15.365898
[Epoch 50/50] [Batch 34/34] [D loss: 0.158762] [G loss: 0.997427] ETA: 0:00:15.460645
5에폭마다 생성모델과 판별모델의 가중치를 저장하여 나중에 학습된 모델을 재구현할 수 있도록 합니다. 그리고 각 모델의 로스값과 생성모델에서 생성한 이미지를 출력하여 어떻게 학습되고 있는지 확인합니다.
생성모델의 로스와 판별모델의 로스의 관계를 확인하여 적절한 에폭 수를 결정하는 것이 중요합니다. 그렇지 못한 경우, 한 모델에 오버피팅되어 더 이상의 학습이 무의미해지게 됩니다.
3.5 예측 및 성능 평가
학습이 끝났으니 새로운 흑백 이미지에 적용하여 컬러 이미지로 예측(채색)해보도록 하겠습니다. 위에서 학습한 모델의 가중치를 불러오고 테스트 데이터 로더의 이미지를 넣어 예측하게 됩니다. 튜토리얼에서는 35에폭의 가중치를 불러오도록 하겠습니다. 이처럼 출력되는 이미지를 보면서 원하시는 에폭의 가중치를 불러와도 됩니다.
# Restore the epoch-35 checkpoints. map_location moves the saved tensors to
# the device each model currently lives on, so checkpoints written on a GPU
# also load on a CPU-only machine.
Gener.load_state_dict(
    torch.load("generator_35.pth", map_location=next(Gener.parameters()).device)
)
Discri.load_state_dict(
    torch.load("discriminator_35.pth", map_location=next(Discri.parameters()).device)
)
<All keys matched successfully>
# Switch both networks to evaluation mode for inference.
Discri.eval()
Gener.eval()

# Run on the device the generator lives on (GPU when available, else CPU).
device = next(Gener.parameters()).device

# One uniform [0, 1) noise map, reused for every test batch.
fixed_noise = torch.rand(test_batch_size, 1, img_height, img_width)

for i, data in enumerate(test_loader, 0):
    images, label = data['A'], data['B']

    # fixed_noise has exactly test_batch_size rows, so a smaller final batch
    # cannot be concatenated with it and is skipped.
    if len(images) != test_batch_size:
        continue

    # The loader already yields (test_batch_size, 1, img_height, img_width)
    # tensors, so they are used as they are.
    grays = images
    gray = torch.cat([grays, fixed_noise], dim=1).to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        output = Gener(gray)

    inputs = torchvision.utils.make_grid(grays)
    labels = torchvision.utils.make_grid(label)
    out = torchvision.utils.make_grid(output)

    print('==================input==================')
    plt.imshow(reNormalize(inputs.cpu(), gray_mean, gray_std))
    plt.show()
    print('==================target==================')
    plt.imshow(reNormalize(labels.cpu(), color_mean, color_std))
    plt.show()
    print('==================output==================')
    plt.imshow(reNormalize(out.cpu(), color_mean, color_std))
    plt.show()
==================input==================
==================target==================
==================output==================
결과는 흑백(input), 컬러(target), 생성된 이미지(output)순으로 출력되어집니다. 선명하지 않지만 비슷한 색감이 어느 정도 채색된 것을 볼 수 있습니다.
์ง๊ธ๊น์ง 3์ฅ์์ ํ๋ฐฑ ์ด๋ฏธ์ง๋ฅผ ์ปฌ๋ฌ ์ด๋ฏธ์ง๋ก ๋ฐ๊พธ๋ ์์ฑ๋ชจ๋ธ GAN์ ๊ตฌ์ถํด๋ณด์์ต๋๋ค. ๋ค์ ์ฅ์์๋ ์ฑ์์ ๋ ํนํ๋์ด์๋ pix2pix๋ชจ๋ธ์ ์ฌ์ฉํ์ฌ ๋น๊ตํด๋ณด๋๋ก ํ๊ฒ ์ต๋๋ค.