4. pix2pix


이전 μž₯μ—μ„œλŠ” GAN λͺ¨λΈμ„ μ΄μš©ν•˜μ—¬, 흑백 이미지λ₯Ό 컬러 μ΄λ―Έμ§€λ‘œ λ³€ν™˜ν•΄λ³΄μ•˜μŠ΅λ‹ˆλ‹€.

이번 μž₯μ—μ„œλŠ” cGAN (conditional Generative Adversarial Network) 기반인 pix2pix λͺ¨λΈκ³Ό 19μ„ΈκΈ° 일러슀트둜 이루어진 Victorian400 데이터셋을 μ΄μš©ν•˜μ—¬, ν•΄λ‹Ή λͺ¨λΈμ„ ν•™μŠ΅ν•˜κ³  색채λ₯Ό μž…νžˆλŠ” ν…ŒμŠ€νŠΈ 해보도둝 ν•˜κ² μŠ΅λ‹ˆλ‹€.

4.1 Downloading the Dataset

First, let's download the Victorian400 dataset. We will use the utility tool made by Pseudo Lab to download the dataset and unzip it.

!git clone https://github.com/Pseudo-Lab/Tutorial-Book-Utils
!python Tutorial-Book-Utils/PL_data_loader.py --data GAN-Colorization
!unzip -q Victorian400-GAN-colorization-data.zip
Cloning into 'Tutorial-Book-Utils'...
remote: Enumerating objects: 27, done.
remote: Counting objects: 100% (27/27), done.
remote: Compressing objects: 100% (23/23), done.
remote: Total 27 (delta 7), reused 13 (delta 3), pack-reused 0
Unpacking objects: 100% (27/27), done.
Victorian400-GAN-colorization-data.zip is done!

We import the basic modules.

import os
import glob
import numpy as np
import datetime
import matplotlib.pyplot as plt
from PIL import Image

from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision.utils import save_image
import torch.nn as nn
import torch.nn.functional as F
import torch
import torchvision
from torch.autograd import Variable

4.2 Defining the Dataset Class

The VictorianDataset class defines an __init__ method that collects the grayscale (gray) and color (resized) image paths together in file-name order, a __getitem__ method that loads each image file as pixel values, and a __len__ method that returns the number of files.

class VictorianDataset(Dataset):
    def __init__(self, root, color_transforms_=None, gray_transforms_=None):

        self.color_transforms = transforms.Compose(color_transforms_)
        self.gray_transforms = transforms.Compose(gray_transforms_)
        self.gray_files = sorted(glob.glob(os.path.join(root, 'gray') + "/*.*"))
        self.color_files = sorted(glob.glob(os.path.join(root, 'resized') + "/*.*"))
     
    def __getitem__(self, index):
        gray_img = Image.open(self.gray_files[index % len(self.gray_files)]).convert("RGB")
        color_img = Image.open(self.color_files[index % len(self.color_files)]).convert("RGB")
    
        gray_img = self.gray_transforms(gray_img)
        color_img = self.color_transforms(color_img)

        return {"A": gray_img, "B": color_img}

    def __len__(self):
        return len(self.gray_files)

배치 μ‚¬μ΄μ¦ˆμ™€ 이미지 μ‚¬μ΄μ¦ˆλ₯Ό 미리 μ§€μ •ν•΄μ€λ‹ˆλ‹€. 폴더 μœ„μΉ˜λ₯Ό root둜 μ§€μ •ν•΄μ€λ‹ˆλ‹€. 이미지 μ‚¬μ΄μ¦ˆμ˜ 경우 높이와 κ°€λ‘œ λͺ¨λ‘ 256으둜 λ§žμΆ°μ€λ‹ˆλ‹€. pix2pix λͺ¨λΈμ˜ 경우 256 x 256 이미지 μ‚¬μ΄μ¦ˆλ₯Ό ν™œμš©ν•©λ‹ˆλ‹€. (μΆ”κ°€ ν•  것)

root = ''

batch_size = 12
img_height = 256
img_width = 256
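
The images in the Victorian400 gray and resized folders are expected to already be 256 x 256. If you apply this notebook to images of a different size, one option (a sketch, not part of the original notebook; resize_ is a hypothetical name) is to prepend a resize step to the transform lists defined in the next cell:

# Sketch: a resize transform that could be placed before transforms.ToTensor()
# in the transform lists below, in case the source images are not already 256 x 256.
resize_ = transforms.Resize((img_height, img_width))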

We specify the normalization parameters in transforms.Normalize. We will normalize with the means and standard deviations computed in section 2.4.

color_mean = [0.58090717, 0.52688643, 0.45678478]
color_std = [0.25644188, 0.25482641, 0.24456465]
gray_mean = [0.5350533, 0.5350533, 0.5350533]
gray_std = [0.25051587, 0.25051587, 0.25051587]

color_transforms_ = [
    transforms.ToTensor(),
    transforms.Normalize(mean=color_mean, std=color_std),
]

gray_transforms_ = [
    transforms.ToTensor(),
    transforms.Normalize(mean=gray_mean, std=gray_std),
]
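
These statistics come from section 2.4. For reference, a minimal sketch of how such per-channel statistics can be estimated from the image folders (an illustration, not the exact code from section 2.4) looks like this:

# Sketch: estimate per-channel mean/std of a folder of images (all assumed to be 256 x 256).
def channel_stats(files):
    to_tensor = transforms.ToTensor()
    imgs = torch.stack([to_tensor(Image.open(f).convert("RGB")) for f in files])  # (N, 3, H, W)
    return imgs.mean(dim=(0, 2, 3)), imgs.std(dim=(0, 2, 3))

# Example usage with the same folders used by VictorianDataset:
# print(channel_stats(sorted(glob.glob(os.path.join(root, 'resized') + "/*.*"))))
# print(channel_stats(sorted(glob.glob(os.path.join(root, 'gray') + "/*.*"))))
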
train_loader  = DataLoader(
    VictorianDataset(root, color_transforms_=color_transforms_, gray_transforms_=gray_transforms_),
    batch_size=batch_size,
    shuffle=True
)
def reNormalize(img, mean, std):
    img = img.numpy().transpose(1, 2, 0)
    img = img * std + mean
    img = img.clip(0, 1)
    return img

Now let's visualize the loaded data to check that the images have been stored as pixels correctly.

fig = plt.figure(figsize=(10, 5))
rows = 1 
cols = 2

for X in train_loader :

    print(X['A'].shape, X['B'].shape)
    ax1 = fig.add_subplot(rows, cols, 1)
    ax1.imshow(reNormalize(X["A"][0], gray_mean, gray_std)) 
    ax1.set_title('gray img')

    ax2 = fig.add_subplot(rows, cols, 2)
    ax2.imshow(reNormalize(X["B"][0], color_mean, color_std))
    ax2.set_title('color img')    

    plt.show()
    break
torch.Size([12, 3, 256, 256]) torch.Size([12, 3, 256, 256])

4.3 λͺ¨λΈ ꡬ좕¢

Now let's design the pix2pix model. A distinguishing feature of pix2pix is that it uses a U-Net rather than a plain encoder-decoder. Unlike a plain encoder-decoder, a U-Net has skip connections that link each encoder layer to the decoder layer of matching size, which gives better localization. For example, if the first encoder layer has size 256 x 256 x 3, the last decoder layer also has size 256 x 256 x 3, and the two are connected. Coupling encoder and decoder layers of the same size in this way lets the U-Net achieve more effective and faster performance.

Now let's build the U-Net generator with these skip connections built in. As explained earlier, the pix2pix model uses a U-Net generator, and the skip connections provide localization between the encoder and decoder layers.
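
The effect of a skip connection on the channel counts can be seen directly in the code below: each UNetUp concatenates its upsampled output with the matching encoder feature map along the channel dimension, which is why up2 through up7 take twice as many input channels as the previous layer produced. A small sketch of that concatenation:

# Sketch: the channel-wise concatenation performed by each skip connection.
decoder_feat = torch.randn(1, 512, 2, 2)   # upsampled decoder feature map
encoder_feat = torch.randn(1, 512, 2, 2)   # matching encoder feature map (the "skip")
merged = torch.cat((decoder_feat, encoder_feat), dim=1)
print(merged.shape)  # torch.Size([1, 1024, 2, 2]) -> the next UNetUp expects 1024 input channels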

def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)

# U-Net generator

class UNetDown(nn.Module):
    def __init__(self, in_size, out_size, normalize=True, dropout=0.0):
        super(UNetDown, self).__init__()
        layers = [nn.Conv2d(in_size, out_size, 4, 2, 1, bias=False)]
        if normalize:
            layers.append(nn.InstanceNorm2d(out_size))
        layers.append(nn.LeakyReLU(0.2))
        if dropout:
            layers.append(nn.Dropout(dropout))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


class UNetUp(nn.Module):
    def __init__(self, in_size, out_size, dropout=0.0):
        super(UNetUp, self).__init__()
        layers = [
            nn.ConvTranspose2d(in_size, out_size, 4, 2, 1, bias=False),
            nn.InstanceNorm2d(out_size),
            nn.ReLU(inplace=True),
        ]
        if dropout:
            layers.append(nn.Dropout(dropout))

        self.model = nn.Sequential(*layers)

    def forward(self, x, skip_input):
        x = self.model(x)
        x = torch.cat((x, skip_input), 1)

        return x


class GeneratorUNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=3):
        super(GeneratorUNet, self).__init__()
        
        self.down1 = UNetDown(in_channels, 64, normalize=False)
        self.down2 = UNetDown(64, 128)
        self.down3 = UNetDown(128, 256)
        self.down4 = UNetDown(256, 512, dropout=0.5)
        self.down5 = UNetDown(512, 512, dropout=0.5)
        self.down6 = UNetDown(512, 512, dropout=0.5)
        self.down7 = UNetDown(512, 512, dropout=0.5)
        self.down8 = UNetDown(512, 512, normalize=False, dropout=0.5)

        self.up1 = UNetUp(512, 512, dropout=0.5)
        self.up2 = UNetUp(1024, 512, dropout=0.5)
        self.up3 = UNetUp(1024, 512, dropout=0.5)
        self.up4 = UNetUp(1024, 512, dropout=0.5)
        self.up5 = UNetUp(1024, 256)
        self.up6 = UNetUp(512, 128)
        self.up7 = UNetUp(256, 64)

        self.final = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(128, out_channels, 4, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        # U-Net generator with skip connections from encoder to decoder
        d1 = self.down1(x)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)
        d5 = self.down5(d4)
        d6 = self.down6(d5)
        d7 = self.down7(d6)
        d8 = self.down8(d7)
        u1 = self.up1(d8, d7)
        u2 = self.up2(u1, d6)
        u3 = self.up3(u2, d5)
        u4 = self.up4(u3, d4)
        u5 = self.up5(u4, d3)
        u6 = self.up6(u5, d2)
        u7 = self.up7(u6, d1)

        return self.final(u7)
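
As a quick sanity check (not part of the original notebook), a dummy 256 x 256 input passed through the generator should come out with the same spatial size:

# Sketch: the U-Net generator maps a 3 x 256 x 256 input to a 3 x 256 x 256 output.
with torch.no_grad():
    out = GeneratorUNet()(torch.randn(1, 3, 256, 256))
print(out.shape)  # expected: torch.Size([1, 3, 256, 256])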

Now let's build the discriminator.

class Discriminator(nn.Module):
    def __init__(self, in_channels=3):
        super(Discriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, normalization=True):
            """Returns downsampling layers of each discriminator block"""
            layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
            if normalization:
                layers.append(nn.InstanceNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *discriminator_block(in_channels * 2, 64, normalization=False),
            *discriminator_block(64, 128),
            *discriminator_block(128, 256),
            *discriminator_block(256, 512),
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(512, 1, 4, padding=1, bias=False)
        )

    def forward(self, img_A, img_B):
        # Concatenate image and condition image by channels to produce input
        img_input = torch.cat((img_A, img_B), 1)
        return self.model(img_input)
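
This discriminator is a PatchGAN: rather than producing a single real/fake score, it outputs a grid of scores, each judging one patch of the (image, condition) pair. A quick shape check (a sketch, not part of the original notebook):

# Sketch: for 256 x 256 inputs the PatchGAN discriminator returns a 16 x 16 score map.
with torch.no_grad():
    score_map = Discriminator()(torch.randn(1, 3, 256, 256), torch.randn(1, 3, 256, 256))
print(score_map.shape)  # expected: torch.Size([1, 1, 16, 16])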

이제 μƒμ„±μž(Generator)와 λΆ„λ³„μž(Discriminator)의 ꡬ쑰λ₯Ό μ‚΄νŽ΄λ³΄λ„λ‘ ν•˜κ² μŠ΅λ‹ˆλ‹€.

GeneratorUNet().apply(weights_init_normal)
GeneratorUNet(
  (down1): UNetDown(
    (model): Sequential(
      (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): LeakyReLU(negative_slope=0.2)
    )
  )
  (down2): UNetDown(
    (model): Sequential(
      (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.2)
    )
  )
  (down3): UNetDown(
    (model): Sequential(
      (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.2)
    )
  )
  (down4): UNetDown(
    (model): Sequential(
      (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.2)
      (3): Dropout(p=0.5, inplace=False)
    )
  )
  (down5): UNetDown(
    (model): Sequential(
      (0): Conv2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.2)
      (3): Dropout(p=0.5, inplace=False)
    )
  )
  (down6): UNetDown(
    (model): Sequential(
      (0): Conv2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.2)
      (3): Dropout(p=0.5, inplace=False)
    )
  )
  (down7): UNetDown(
    (model): Sequential(
      (0): Conv2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.2)
      (3): Dropout(p=0.5, inplace=False)
    )
  )
  (down8): UNetDown(
    (model): Sequential(
      (0): Conv2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): LeakyReLU(negative_slope=0.2)
      (2): Dropout(p=0.5, inplace=False)
    )
  )
  (up1): UNetUp(
    (model): Sequential(
      (0): ConvTranspose2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (2): ReLU(inplace=True)
      (3): Dropout(p=0.5, inplace=False)
    )
  )
  (up2): UNetUp(
    (model): Sequential(
      (0): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (2): ReLU(inplace=True)
      (3): Dropout(p=0.5, inplace=False)
    )
  )
  (up3): UNetUp(
    (model): Sequential(
      (0): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (2): ReLU(inplace=True)
      (3): Dropout(p=0.5, inplace=False)
    )
  )
  (up4): UNetUp(
    (model): Sequential(
      (0): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (2): ReLU(inplace=True)
      (3): Dropout(p=0.5, inplace=False)
    )
  )
  (up5): UNetUp(
    (model): Sequential(
      (0): ConvTranspose2d(1024, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (2): ReLU(inplace=True)
    )
  )
  (up6): UNetUp(
    (model): Sequential(
      (0): ConvTranspose2d(512, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (2): ReLU(inplace=True)
    )
  )
  (up7): UNetUp(
    (model): Sequential(
      (0): ConvTranspose2d(256, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (2): ReLU(inplace=True)
    )
  )
  (final): Sequential(
    (0): Upsample(scale_factor=2.0, mode=nearest)
    (1): ZeroPad2d(padding=(1, 0, 1, 0), value=0.0)
    (2): Conv2d(128, 3, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
    (3): Tanh()
  )
)
Discriminator().apply(weights_init_normal)
Discriminator(
  (model): Sequential(
    (0): Conv2d(6, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (6): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (9): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): ZeroPad2d(padding=(1, 0, 1, 0), value=0.0)
    (12): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1), bias=False)
  )
)

μƒμ„±μžμ™€ λΆ„λ³„μž μž‘λ™μ›λ¦¬λ₯Ό μ‹œκ°ν™” ν•œ κ·Έλ¦Ό 4-1λ₯Ό λ³΄μ‹œλ©΄, μƒμ„±μžμ— μ˜ν•΄ μƒμ„±λœ μ΄λ―Έμ§€λŠ” μ•„μ›ƒν’‹μœΌλ‘œ 인풋 이미지와 쌍으둜 이루어져 λΆ„λ³„μžμ— μ˜ν•΄ μ–Όλ§ˆλ‚˜ λΉ„μŠ·ν•œ 지 νŒλ‹¨ν•˜κ²Œ λ©λ‹ˆλ‹€. λ˜ν•œ 인풋 이미지와 νƒ€κ²Ÿ 이미지도 λ™μ‹œμ— μž…λ ₯이 λ˜μ–΄ λΆ„λ³„μžμ— μ˜ν•΄ λΉ„κ΅λ˜κ²Œ λ©λ‹ˆλ‹€. 이 두 μŒμ„ λΉ„κ΅ν•œ κ²°κ³Ό 값을 λΆ„λ³„μž κ°’(Discriminator weights)인데, 이 과정을 κ±°μΉ˜λ©΄μ„œ μ—…λ°μ΄νŠΈ 되게 λ©λ‹ˆλ‹€.

λΆ„λ³„μž 값이 κ°±μ‹ λ˜λ©΄, μƒμ„±μž κ°’(Generator weights)도 μ•„λž˜μ˜ 과정을 톡해 κ°±μ‹ λ˜λ©΄μ„œ μƒˆλ‘œμš΄ 이미지λ₯Ό μƒμ„±ν•˜κ²Œ λ©λ‹ˆλ‹€. λͺ¨λΈ ν•™μŠ΅μ€ μ΄λŸ¬ν•œ 과정을 계속 λ°˜λ³΅ν•˜κ²Œ λ©λ‹ˆλ‹€.

이제 νŒŒλΌλ―Έν„°λ₯Ό μ§€μ •ν•˜κ³  pix2pix λͺ¨λΈμ„ ν•™μŠ΅ν•΄λ³΄λ„λ‘ ν•˜κ² μŠ΅λ‹ˆλ‹€. μ—¬κΈ°μ„œ n_epoch은 총 ν•™μŠ΅ν•  에폭 횟수이고, lrλŠ” ν•™μŠ΅ 손싀값(Learning Loss)을 μ˜λ―Έν•©λ‹ˆλ‹€. checkpoint_interval은 ν•™μŠ΅μ€‘ λͺ¨λΈμ˜ κ°€μ€‘μΉ˜κ°€ μ €μž₯λ˜λŠ” κ°„κ²©μž…λ‹ˆλ‹€.

n_epochs = 100
dataset_name = "Victorian400"
lr = 0.0002
b1 = 0.5                    # adam: decay of first order momentum of gradient
b2 = 0.999                  # adam: decay of second order momentum of gradient
decay_epoch = 100           # epoch from which to start lr decay
#n_cpu = 8                   # number of cpu threads to use during batch generation
channels = 3                # number of image channels
checkpoint_interval = 20    # interval between model checkpoints
os.makedirs("images/%s/val" % dataset_name, exist_ok=True)
os.makedirs("images/%s/test" % dataset_name, exist_ok=True)
os.makedirs("saved_models/%s" % dataset_name, exist_ok=True)

# Loss functions
criterion_GAN = torch.nn.MSELoss()
criterion_pixelwise = torch.nn.L1Loss()

# Loss weight of L1 pixel-wise loss between translated image and real image
lambda_pixel = 100

# Calculate output of image discriminator (PatchGAN)
patch = (1, img_height // 2 ** 4, img_width // 2 ** 4)    # (1, 16, 16) for 256 x 256 images

# Initialize generator and discriminator
generator = GeneratorUNet()
discriminator = Discriminator()

cuda = True if torch.cuda.is_available() else False

if cuda:
    generator = generator.cuda()
    discriminator = discriminator.cuda()
    criterion_GAN.cuda()
    criterion_pixelwise.cuda()

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))

# Tensor type
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
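
Note that decay_epoch is defined above but never used in this notebook. If you want the learning rate to actually decay from that epoch on, one possible addition (a sketch, not part of the original code) is a LambdaLR scheduler stepped once per epoch:

# Sketch: linear learning-rate decay starting at decay_epoch (not wired into the training loop below).
def linear_decay(epoch):
    return 1.0 - max(0, epoch - decay_epoch) / float(n_epochs - decay_epoch + 1)

scheduler_G = torch.optim.lr_scheduler.LambdaLR(optimizer_G, lr_lambda=linear_decay)
scheduler_D = torch.optim.lr_scheduler.LambdaLR(optimizer_D, lr_lambda=linear_decay)
# Call scheduler_G.step() and scheduler_D.step() at the end of every epoch to apply the decay.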

In the sample_images function defined below, there are gray, color, and output images: gray is the grayscale photo, color is the color photo, and output is the colorized version of the grayscale photo. The model is trained by comparing gray against color, and generates output based on what it has learned.

def sample_images(epoch, loader, mode):
    imgs = next(iter(loader))
    gray = Variable(imgs["A"].type(Tensor))
    color = Variable(imgs["B"].type(Tensor))
    output = generator(gray)    
    
    gray_img = torchvision.utils.make_grid(gray.data, nrow=6) 
    color_img = torchvision.utils.make_grid(color.data, nrow=6)  
    output_img = torchvision.utils.make_grid(output.data, nrow=6)

    rows = 3
    cols = 1

    ax1 = fig.add_subplot(rows, cols, 1)
    ax1.imshow(reNormalize(gray_img.cpu(), gray_mean, gray_std)) 
    ax1.set_title('gray')

    ax2 = fig.add_subplot(rows, cols, 2)
    ax2.imshow(reNormalize(color_img.cpu(), color_mean, color_std))
    ax2.set_title('color')  

    ax3 = fig.add_subplot(rows, cols, 3)
    ax3.imshow(reNormalize(output_img.cpu(), color_mean, color_std))
    ax3.set_title('output')  

    plt.show()
    fig.savefig("images/%s/%s/epoch_%s.png" % (dataset_name, mode, epoch), pad_inches=0)

4.4 λͺ¨λΈ ν•™μŠ΅ΒΆ

Now let's start training for the specified number of epochs.

# ----------
#  Training
# ----------

for epoch in range(1, n_epochs+1):
    start_time = datetime.datetime.now()
    for i, batch in enumerate(train_loader):

        # Model inputs
        gray = Variable(batch["A"].type(Tensor))
        color = Variable(batch["B"].type(Tensor))

        # Adversarial ground truths
        valid = Variable(Tensor(np.ones((gray.size(0), *patch))), requires_grad=False)
        fake = Variable(Tensor(np.zeros((gray.size(0), *patch))), requires_grad=False)

        # ------------------
        #  Train Generators
        # ------------------

        optimizer_G.zero_grad()

        # GAN loss
        output = generator(gray)
        pred_fake = discriminator(output, gray)
        loss_GAN = criterion_GAN(pred_fake, valid)
        # Pixel-wise loss
        loss_pixel = criterion_pixelwise(output, color)

        # Total loss
        loss_G = loss_GAN + lambda_pixel * loss_pixel

        loss_G.backward()

        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Real loss
        pred_real = discriminator(color, gray)
        loss_real = criterion_GAN(pred_real, valid)

        # Fake loss
        pred_fake = discriminator(output.detach(), gray)
        loss_fake = criterion_GAN(pred_fake, fake)

        # Total loss
        loss_D = 0.5 * (loss_real + loss_fake)

        loss_D.backward()
        optimizer_D.step()

        epoch_time = datetime.datetime.now() - start_time

    if (epoch) % checkpoint_interval == 0:
        fig = plt.figure(figsize=(18, 18))
        sample_images(epoch, train_loader, 'val')

        torch.save(generator.state_dict(), "saved_models/%s/generator_%d.pth" % (dataset_name, epoch))
        torch.save(discriminator.state_dict(), "saved_models/%s/discriminator_%d.pth" % (dataset_name, epoch))

        print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, pixel: %f, adv: %f] ETA: %s" % (epoch, 
                                                                                                    n_epochs, 
                                                                                                    i+1, 
                                                                                                    len(train_loader), 
                                                                                                    loss_D.item(), 
                                                                                                    loss_G.item(), 
                                                                                                    loss_pixel.item(), 
                                                                                                    loss_GAN.item(), 
                                                                                                    epoch_time))
     
(At each checkpoint epoch, a grid of sampled gray / color / output images is displayed.)
[Epoch 20/100] [Batch 34/34] [D loss: 0.003363] [G loss: 36.174591, pixel: 0.351734, adv: 1.001226] ETA: 0:00:15.931600
[Epoch 40/100] [Batch 34/34] [D loss: 0.010069] [G loss: 22.281427, pixel: 0.212988, adv: 0.982629] ETA: 0:00:15.914369
[Epoch 60/100] [Batch 34/34] [D loss: 0.001813] [G loss: 29.513786, pixel: 0.284740, adv: 1.039806] ETA: 0:00:15.915187
[Epoch 80/100] [Batch 34/34] [D loss: 0.001550] [G loss: 18.294107, pixel: 0.172993, adv: 0.994772] ETA: 0:00:15.893250
[Epoch 100/100] [Batch 34/34] [D loss: 0.399534] [G loss: 22.823000, pixel: 0.224467, adv: 0.376251] ETA: 0:00:15.921102

μœ„μ˜ μƒ˜ν”Œ 사진듀을 λ³΄μ‹œλ©΄, μœ„μ—μ„œ μ•„λž˜ μˆœμ„œλŒ€λ‘œ 흑백-타켓-아웃풋 μ΄λ―Έμ§€μž…λ‹ˆλ‹€. ν™•μ‹€νžˆ μ—ν­μˆ˜κ°€ λŠ˜μ–΄λ‚¨μœΌλ‘œμ¨ ν•™μŠ΅ νš¨κ³Όκ°€ λ‚˜νƒ€λ‚˜κ³  μžˆλŠ” κ±Έ 확인 ν•  수 μžˆμŠ΅λ‹ˆλ‹€. μ΄λ ‡κ²Œ μƒ˜ν”Œλ§ 된 이미지λ₯Ό ν™•μΈν•˜λ©΄μ„œ μ μ ˆν•œ λ°°μΉ˜μ‚¬μ΄μ¦ˆμ™€ μ—ν­μˆ˜λ₯Ό 찾을 수 μžˆμŠ΅λ‹ˆλ‹€.

4.5 Prediction and Performance Evaluation

이제 ν•™μŠ΅λœ λͺ¨λΈμ„ μ΄μš©ν•΄ 6μž₯의 ν…ŒμŠ€νŠΈμ…‹μœΌλ‘œ μ‹€ν—˜ν•΄λ³΄λ„λ‘ ν•˜κ² μŠ΅λ‹ˆλ‹€.

test_root = root + 'test/'
test_batch_size = 6

test_loader = DataLoader(
    VictorianDataset(test_root, color_transforms_=color_transforms_, gray_transforms_=gray_transforms_),
    batch_size=test_batch_size,
    shuffle=False
)

Let's check that the test set image files are displayed correctly.

fig = plt.figure(figsize=(10, 5))
rows = 1 
cols = 2

for X in test_loader:

    print(X['A'].shape, X['B'].shape)
    ax1 = fig.add_subplot(rows, cols, 1)
    ax1.imshow(reNormalize(X["A"][0], gray_mean, gray_std)) 
    ax1.set_title('gray img')

    ax2 = fig.add_subplot(rows, cols, 2)
    ax2.imshow(reNormalize(X["B"][0], color_mean, color_std))
    ax2.set_title('color img')    

    plt.show()
    break
torch.Size([6, 3, 256, 256]) torch.Size([6, 3, 256, 256])

이제 ν•™μŠ΅λœ λͺ¨λΈμ„ λΆˆλŸ¬μ™€ ν…ŒμŠ€νŠΈμ…‹ 이미지 νŒŒμΌλ“€μ„ μ˜ˆμΈ‘ν•΄λ³΄λ„λ‘ ν•˜κ² μŠ΅λ‹ˆλ‹€. μ•„λž˜ μ½”λ“œλŠ” μ΅œλŒ€ μ—ν­μˆ˜λ‘œ ν•™μŠ΅λœ λͺ¨λΈμ„ μ μš©ν•©λ‹ˆλ‹€. μ›ν•˜λŠ” μ—ν­μˆ˜λ₯Ό n_epochs에 μ§€μ •ν•˜λ©΄, ν•΄λ‹Ή μ—ν­μˆ˜μ˜ ν•™μŠ΅λœ λͺ¨λΈμ„ 뢈러올 수 μžˆμŠ΅λ‹ˆλ‹€.

generator.load_state_dict(torch.load("saved_models/%s/generator_%d.pth" % (dataset_name, n_epochs)))
discriminator.load_state_dict(torch.load("saved_models/%s/discriminator_%d.pth" % (dataset_name, n_epochs)))
<All keys matched successfully>
generator.eval()
discriminator.eval()

fig = plt.figure(figsize=(18,10))
sample_images(n_epochs, test_loader, 'test')

μœ„μ—μ„œ μ•„λž˜ μˆœμ„œλŒ€λ‘œ 흑백-타켓-아웃풋 ν…ŒμŠ€νŠΈ μ΄λ―Έμ§€μž…λ‹ˆλ‹€. μ–΄λ–€ 사진은 원본보닀 좜λ ₯이 더 잘 λ˜λŠ” 것을 확인 ν•  수 μžˆμŠ΅λ‹ˆλ‹€. μƒμ„±μž(Generator)에 U-NET이 μΆ”κ°€λœ cGAN λͺ¨λΈμ΄ GAN λͺ¨λΈλ³΄λ‹€ 색채λ₯Ό μ˜ˆμΈ‘ν•˜λŠ” 점에 μžˆμ–΄μ„œ, 보닀 λ‚˜μ€ 결과물을 μƒμ„±ν•˜λŠ” 것을 확인해 λ³Ό 수 μžˆμ—ˆμŠ΅λ‹ˆλ‹€.

λͺ¨λΈ ꡬ쑰가 ꡬ체적인 λͺ©ν‘œμ— 맞좰 잘 μ§œμ—¬μ Έ μžˆλ‹€λ©΄, μ΄λ ‡κ²Œ μž‘μ€ λ°μ΄ν„°μ…‹μœΌλ‘œλ„ 쒋은 결과물을 λ§Œλ“€ 수 μžˆμŠ΅λ‹ˆλ‹€. λ¬Όλ‘  λ§Žμ€ μ–‘μ˜ 질 쒋은 데이터가 μΆ”κ°€λœλ‹€λ©΄, ν–₯상 된 κ²°κ³Όλ₯Ό 얻을 수 μžˆμŠ΅λ‹ˆλ‹€. 이제 ν•™μŠ΅λœ λͺ¨λΈμ„ λ°”νƒ•μœΌλ‘œ μ—¬λŸ¬λΆ„μ˜ 흑백 그림을 μ±„μƒ‰ν•΄λ³΄λŠ”κ²Œ μ–΄λ–¨κΉŒμš”?