
3. GAN


In Chapter 2 we visualized the images in the Victorian400 dataset folder by folder. In Chapter 3 we will build a model that uses a GAN to convert grayscale images into color images.

In Section 3.1 we load the data used for modeling, and in Section 3.2 we define the dataset class and the data loaders. In Section 3.3 we build the model by defining the Generator and Discriminator classes. In Section 3.4 we set the hyperparameters and train the model, and in Section 3.5 we run predictions on the test data and evaluate the performance.

3.1 Downloading the Dataset

Using the code from Section 2.1, we download the dataset for the model: clone the Pseudo Lab GitHub repository Tutorial-Book-Utils, download the Victorian400 dataset with the PL_data_loader.py script, and unzip the archive.

!git clone https://github.com/Pseudo-Lab/Tutorial-Book-Utils
!python Tutorial-Book-Utils/PL_data_loader.py --data GAN-Colorization
!unzip -q Victorian400-GAN-colorization-data.zip
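To confirm the archive extracted as expected, a quick file count per folder can help. This is a small sketch of ours, assuming the gray/, resized/, and test/ folders that the rest of this chapter reads from.

import glob, os

# Count the files in each folder used later in this chapter (sketch; adjust if your layout differs).
for folder in ['gray', 'resized', 'test/gray', 'test/resized']:
    print(folder, len(glob.glob(os.path.join(folder, '*.*'))), 'files')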

Next, we import the packages used throughout Chapter 3. os and glob are used to work with file paths, and datetime is used for date and time calculations. numpy is a linear-algebra package used for numerical operations, while matplotlib and PIL's Image are used to visualize image files. Finally, torch and torchvision are the Torch-based packages used to define the dataset and build the model.

import os
import glob
import datetime
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader

3.2 Defining the Dataset Class

To feed data into a PyTorch model, we define a dataset class. It takes the folder paths of the grayscale and color images together with the image transforms as arguments and loads each image with Image.open. Grayscale images are read as a single channel with .convert("L") and color images as three channels with .convert("RGB"). After the transforms are applied, the grayscale and color images are returned as a dictionary under the keys "A" and "B", respectively.

class VictorianDataset(Dataset):
    def __init__(self, root, color_transforms_=None, gray_transforms_=None):

        self.color_transforms = transforms.Compose(color_transforms_)
        self.gray_transforms = transforms.Compose(gray_transforms_)
        self.gray_files = sorted(glob.glob(os.path.join(root, 'gray') + "/*.*"))
        self.color_files = sorted(glob.glob(os.path.join(root, 'resized') + "/*.*"))
     
    def __getitem__(self, index):

        gray_img = Image.open(self.gray_files[index % len(self.gray_files)]).convert("L")
        color_img = Image.open(self.color_files[index % len(self.color_files)]).convert("RGB")
    
        gray_img = self.gray_transforms(gray_img)
        color_img = self.color_transforms(color_img)

        return {"A": gray_img, "B": color_img}

    def __len__(self):
        return len(self.gray_files)
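As a quick usage check (our sketch, not part of the original notebook), the class can be instantiated directly with a minimal ToTensor transform before the full pipeline is set up:

# Minimal sanity check of VictorianDataset (assumes the gray/ and resized/ folders exist).
_check_ds = VictorianDataset('', color_transforms_=[transforms.ToTensor()],
                             gray_transforms_=[transforms.ToTensor()])
_sample = _check_ds[0]
print(len(_check_ds), _sample["A"].shape, _sample["B"].shape)  # number of pairs, [1, H, W], [3, H, W]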

Next we specify the image paths, the image size, and the batch size for each data loader. The image size and batch size can be set according to your environment (available memory). Note, however, that the model architecture that follows is fixed to 256, so using a different image size also requires small changes to the model.

root = ''
test_root = root + 'test/'

img_height = 256
img_width = 256

batch_size = 12
test_batch_size = 6

gpu = 0
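The gpu variable above is the CUDA device index used later in the .cuda(gpu) calls. A quick availability check (our addition, not in the original notebook) can save confusion on CPU-only machines:

# Report whether a GPU is available; the model code below falls back to CPU if it is not.
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("Device name:", torch.cuda.get_device_name(gpu))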

We now specify the image transforms passed to the dataset class. In this tutorial we convert the images to tensors so they can be fed to the PyTorch model and normalize them with the mean and standard deviation computed in Section 2.4. Depending on the model architecture and the domain, other transforms can be applied as well, such as Resize, RandomCrop (random cropping), or RandomVerticalFlip (random vertical flipping).

color_mean = [0.58090717, 0.52688643, 0.45678478]
color_std = [0.25644188, 0.25482641, 0.24456465]
gray_mean = [0.5350533]
gray_std = [0.25051587]

color_transforms_ = [
    transforms.Resize(size=(img_height, img_width)),
    transforms.ToTensor(),
    transforms.Normalize(mean=color_mean, std=color_std),
]

gray_transforms_ = [
    transforms.Resize(size=(img_height, img_width)),
    transforms.ToTensor(),
    transforms.Normalize(mean=gray_mean, std=gray_std),
]

We pass the dataset class defined above, together with the image transforms, to DataLoader.

train_loader = DataLoader(
    VictorianDataset(root, color_transforms_=color_transforms_, gray_transforms_=gray_transforms_),
    batch_size=batch_size,
    shuffle=True
)

To verify that the data loader is set up correctly, let's visualize a batch. The data loader holds normalized images, so the normalization must be reversed before visualizing. We therefore define a reNormalize function that multiplies by the standard deviation and adds back the mean. Here .transpose() reorders the axes and .clip(min, max) clamps values below min to min and values above max to max. In addition, since the grayscale images under "A" have a single channel, they must be reshaped to 2D with .reshape() and displayed with cmap='gray' to render correctly.

def reNormalize(img, mean, std):
    img = img.numpy().transpose(1, 2, 0)
    img = img * std + mean
    img = img.clip(0, 1)
    return img

fig = plt.figure(figsize=(10,5))
rows = 1 
cols = 2

for X in train_loader:
    
    print(X['A'].shape, X['B'].shape)
    ax1 = fig.add_subplot(rows, cols, 1)
    ax1.imshow(reNormalize(X["A"][0], gray_mean, gray_std).reshape(img_height, img_width), cmap='gray') 
    ax1.set_title('gray img')

    ax2 = fig.add_subplot(rows, cols, 2)
    ax2.imshow(reNormalize(X["B"][0], color_mean, color_std))
    ax2.set_title('color img')    

    plt.show()
    break
torch.Size([12, 1, 256, 256]) torch.Size([12, 3, 256, 256])
[Figure: a sample 'gray img' and its 'color img' counterpart from the training loader]

The data loader is constructed correctly and the grayscale and color images are displayed as expected. In the same way, we define and visualize the test data loader.

test_loader = DataLoader(
    VictorianDataset(test_root, color_transforms_=color_transforms_, gray_transforms_=gray_transforms_),
    batch_size=test_batch_size,
    shuffle=False
)

fig = plt.figure(figsize=(10,5))
rows = 1 
cols = 2

for X in test_loader:
    
    print(X['A'].shape, X['B'].shape)
    ax1 = fig.add_subplot(rows, cols, 1)
    ax1.imshow(reNormalize(X["A"][0], gray_mean, gray_std).reshape(img_height, img_width), cmap='gray')
    ax1.set_title('gray img')

    ax2 = fig.add_subplot(rows, cols, 2)
    ax2.imshow(reNormalize(X["B"][0], color_mean, color_std))
    ax2.set_title('color img')    

    plt.show()
    break
torch.Size([6, 1, 256, 256]) torch.Size([6, 3, 256, 256])
[Figure: a sample 'gray img' and its 'color img' counterpart from the test loader]

The test data loader is also constructed correctly.

3.3 Building the Model

Now let's build the model that will be trained on this dataset. In Section 3.3 we use a GAN to convert grayscale images into color images. A GAN consists of a generator and a discriminator, each built from Conv2d, ConvTranspose2d, and MaxPool2d layers together with normalization layers and activation functions.

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        """noise image + gray image"""
        self.conv1 = nn.Sequential(
            nn.Conv2d(2, 64, 3, 1, 1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.1)
        )
        
        self.maxpool = nn.MaxPool2d(2,2)
        
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 64 * 2 , 3, 1, 1),
            nn.BatchNorm2d(64 * 2),
            nn.LeakyReLU(0.1)
        )
        
        self.upsample = nn.Sequential(
            nn.ConvTranspose2d(64 * 2, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2)
        )
        
        self.conv1by1 = nn.Sequential(
            nn.Conv2d(64,64,1,1,0),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.1)
        )
        
        
        self.conv = nn.Sequential(
            nn.Conv2d(64, 3, 3, 1, 1),
            nn.Tanh()
        )

    def forward(self, input):
        output1 = self.conv1(input)
        pool1 = self.maxpool(output1)
        output2 = self.conv2(pool1)
        output3 = self.upsample(output2) + output1
        output4 = self.conv1by1(output3)
        out = self.conv(output4)        
        return out

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(3,64,kernel_size = 4, stride = 2, padding = 1, bias = False),
            nn.LeakyReLU(0.2, inplace = True),
            
            nn.Conv2d(64,128,kernel_size = 4, stride = 2, padding = 1, bias = False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace = True),
            
            nn.Conv2d(128, 128, kernel_size = 1, stride = 1, padding = 0, bias = False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace = True),
            
            nn.Conv2d(128, 256, kernel_size = 4, stride = 2, padding = 1, bias = False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace = True),
            
            nn.Conv2d(256, 512, kernel_size = 4, stride = 2, padding = 1, bias = False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace = True),
            
            nn.Conv2d(512, 1024, kernel_size = 4, stride = 2, padding = 1, bias = False),
            nn.BatchNorm2d(1024),
            nn.LeakyReLU(0.2, inplace = True),
            
            )
        
        
        self.fc = nn.Sequential(
            nn.Linear(1024 * 8 * 8 , 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, batch_size),
            nn.Sigmoid()
        )
        
    def forward(self, input, b_size):
        output = self.main(input)
        output = self.fc(output.view(b_size,-1))
        return output
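Before wiring the two networks into the training loop, a quick shape check with dummy tensors (a sketch of ours, runnable on CPU) confirms that the generator maps a 2-channel input (grayscale + noise) to a 3-channel image of the same resolution, and that the discriminator returns one score vector per sample:

# Shape sanity check with random dummy tensors (sketch; not part of the original notebook).
_g, _d = Generator(), Discriminator()
_dummy = torch.randn(2, 2, img_height, img_width)   # grayscale channel + noise channel
_fake = _g(_dummy)
print(_fake.shape)                # expected: torch.Size([2, 3, 256, 256])
print(_d(_fake, b_size=2).shape)  # expected: torch.Size([2, 12]); the fc head outputs batch_size units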

Before training, the weights of each layer need to be initialized. We therefore define a weights_init function and apply it to both the generator and the discriminator.

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:        
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1: 
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

Next, we move the generator and the discriminator to the GPU with .cuda(gpu) so they can be trained there. Here gpu is the device index defined above. The printed output shows each model's architecture.

Gener = Generator().cuda(gpu) if torch.cuda.is_available() else Generator()
Gener.apply(weights_init)
Generator(
  (conv1): Sequential(
    (0): Conv2d(2, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.1)
  )
  (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Sequential(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.1)
  )
  (upsample): Sequential(
    (0): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2)
  )
  (conv1by1): Sequential(
    (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.1)
  )
  (conv): Sequential(
    (0): Conv2d(64, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): Tanh()
  )
)
Discri = Discriminator().cuda(gpu) if torch.cuda.is_available() else Discriminator()
Discri.apply(weights_init)
Discriminator(
  (main): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (6): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (9): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (12): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (13): LeakyReLU(negative_slope=0.2, inplace=True)
    (14): Conv2d(512, 1024, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (15): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (16): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (fc): Sequential(
    (0): Linear(in_features=65536, out_features=1024, bias=True)
    (1): LeakyReLU(negative_slope=0.2)
    (2): Linear(in_features=1024, out_features=256, bias=True)
    (3): LeakyReLU(negative_slope=0.2)
    (4): Linear(in_features=256, out_features=12, bias=True)
    (5): Sigmoid()
  )
)

3.4 Training the Model

Now let's train the model. We train for 50 epochs and use the Adam optimizer for both the generator and the discriminator. Each hyperparameter can be tuned to run different experiments.

Training alternates between the generator and the discriminator. The generator is trained to produce color images that can fool the discriminator, while the discriminator is trained to recognize real color images as real and the generated images as fake.
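Concretely, the loop below uses least-squares (LSGAN-style) losses rather than the original cross-entropy GAN loss. Writing G for the generator, D for the discriminator, x for the grayscale input, z for the noise, and y for the color image, the quantities minimized are

$$\mathcal{L}_G = \mathbb{E}\big[(D(G(x, z)) - 1)^2\big], \qquad \mathcal{L}_D = \mathbb{E}\big[(D(y) - 1)^2\big] + \mathbb{E}\big[D(G(x, z))^2\big],$$

which correspond to the torch.mean((output - 1) ** 2) and torch.mean(output ** 2) terms in the code.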

max_epoch = 50
optimizerD = torch.optim.Adam(Discri.parameters(), lr = 0.0002,betas = (0.5, 0.999))
optimizerG = torch.optim.Adam(Gener.parameters(), lr = 0.0002, betas = (0.5, 0.999))
for epoch in range(max_epoch):
    start_time = datetime.datetime.now()
    loss_D = 0.0
    for i, data in enumerate(train_loader):
        grays, color = data['A'], data['B']
        b_size = len(data['A'])
        

        ######## Train Generator ########
        noise = torch.randn(b_size, 1, img_height, img_width).uniform_(0,1)
        gray_noise = Variable(torch.cat([grays,noise],dim=1).cuda(gpu))
        fake_img = Gener(gray_noise)
        output = Discri(fake_img,b_size)
        g_loss = torch.mean((output-1)**2)
        

        ######## Backpropagation & Optimize G ########
        Discri.zero_grad()
        Gener.zero_grad()
        g_loss.backward()
        optimizerG.step()


        ######## Train Discriminator ########
        color = Variable(color.cuda(gpu))
        noise = torch.randn(b_size, 1, img_height, img_width).uniform_(0,1)   
        gray_noise = Variable(torch.cat([grays,noise],dim=1).cuda(gpu))   


        ######## Train D to recognize color images as real ########
        output = Discri(color,b_size)
        real_loss = torch.mean((output-1)**2)
        

        ######## Train D to recognize fake (generated) images as fake ########
        fake_img = Gener(gray_noise)   
        output = Discri(fake_img,b_size)
        fake_loss = torch.mean(output**2)
        

        ######## Backpropagation & Optimize D ########
        d_loss = real_loss + fake_loss
        Discri.zero_grad()
        Gener.zero_grad()
        d_loss.backward()
        optimizerD.step()        


        fake_img = torchvision.utils.make_grid(fake_img.data) 

    epoch_time = datetime.datetime.now() - start_time
    if (epoch + 1) % 5 == 0:
        torch.save(Gener.state_dict(), "generator_%d.pth" % (epoch+1))
        torch.save(Discri.state_dict(), "discriminator_%d.pth" % (epoch+1))

        print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] ETA: %s" % (epoch+1, max_epoch, i+1, len(train_loader), d_loss.item(), g_loss.item(), epoch_time))
        plt.imshow(reNormalize(fake_img.cpu(), color_mean, color_std))
        plt.show()      
[Epoch 5/50] [Batch 34/34] [D loss: 0.335707] [G loss: 0.952135] ETA: 0:00:15.254198
[Epoch 10/50] [Batch 34/34] [D loss: 0.232681] [G loss: 0.925357] ETA: 0:00:16.087967
[Epoch 15/50] [Batch 34/34] [D loss: 0.222960] [G loss: 0.742795] ETA: 0:00:14.794427
[Epoch 20/50] [Batch 34/34] [D loss: 0.002736] [G loss: 0.968330] ETA: 0:00:15.417760
[Epoch 25/50] [Batch 34/34] [D loss: 0.000585] [G loss: 0.999990] ETA: 0:00:15.662105
[Epoch 30/50] [Batch 34/34] [D loss: 0.008619] [G loss: 0.997328] ETA: 0:00:15.388837
[Epoch 35/50] [Batch 34/34] [D loss: 0.201541] [G loss: 0.998293] ETA: 0:00:15.455658
[Epoch 40/50] [Batch 34/34] [D loss: 0.011544] [G loss: 0.999389] ETA: 0:00:15.419755
[Epoch 45/50] [Batch 34/34] [D loss: 0.000001] [G loss: 0.999999] ETA: 0:00:15.365898
[Epoch 50/50] [Batch 34/34] [D loss: 0.158762] [G loss: 0.997427] ETA: 0:00:15.460645
[Figures: a grid of generated images is displayed after each of the logged epochs above]

Every 5 epochs we save the weights of the generator and the discriminator so that the trained model can be restored later. We also print each model's loss and an image produced by the generator to check how training is progressing.

It is important to monitor the relationship between the generator loss and the discriminator loss when deciding on an appropriate number of epochs. Otherwise one of the models can come to dominate (overfit), and further training becomes meaningless.
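The loop above only prints the losses every 5 epochs. A small addition such as the sketch below (not part of the original code; g_losses and d_losses would need to be appended to once per epoch inside the loop, e.g. g_losses.append(g_loss.item())) makes it easier to compare the two curves and choose a stopping point.

# Sketch: plot per-epoch generator and discriminator losses collected during training.
g_losses, d_losses = [], []   # fill these inside the training loop, once per epoch

def plot_losses(g_losses, d_losses):
    plt.plot(g_losses, label='G loss')
    plt.plot(d_losses, label='D loss')
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.legend()
    plt.show()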

3.5 Prediction and Performance Evaluation

Now that training is finished, let's apply the model to new grayscale images and predict (colorize) them. We load the trained weights and feed the images from the test data loader into the model. In this tutorial we load the weights from epoch 35; based on the images printed during training, you can load the weights from whichever epoch you prefer.

Gener.load_state_dict(torch.load("generator_35.pth" ))
Discri.load_state_dict(torch.load("discriminator_35.pth" ))
<All keys matched successfully>
Discri.eval()
Gener.eval()

fixed_noise = torch.randn(test_batch_size, 1, img_height, img_width).uniform_(0,1)

for i, data in enumerate(test_loader,0) :
    images, label = data['A'], data['B']

    if len(data['A']) != test_batch_size:
        continue

    grays = torch.from_numpy(np.resize(images.numpy(), (test_batch_size, 1, img_height, img_width)))    
    gray = Variable(torch.cat([grays,fixed_noise],dim = 1).cuda(gpu))
    
    output = Gener(gray)

    inputs = torchvision.utils.make_grid(grays)
    labels = torchvision.utils.make_grid(label)
    out = torchvision.utils.make_grid(output.data)

    print('==================input==================')
    plt.imshow(reNormalize(inputs.cpu(), gray_mean, gray_std))
    plt.show()
    print('==================target==================')
    plt.imshow(reNormalize(labels.cpu(), color_mean, color_std))
    plt.show()
    print('==================output==================')
    plt.imshow(reNormalize(out.cpu(), color_mean, color_std))
    plt.show()
==================input==================
[Figure: grid of grayscale input images]
==================target==================
[Figure: grid of ground-truth color images]
==================output==================
[Figure: grid of colorized images generated by the model]

The results are displayed in the order grayscale input, color target, and generated output. Although the colors are not vivid, a similar color palette has been filled in to some extent.
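The evaluation above is purely visual. As an optional quantitative check (our sketch, not part of the original tutorial), a simple pixel-level metric such as PSNR can be computed between the de-normalized outputs and targets of the last test batch:

# Sketch: average PSNR over the last test batch; 'output' and 'label' are left over from the loop above.
def psnr(pred, target, max_val=1.0):
    # Peak signal-to-noise ratio between two images scaled to [0, max_val].
    mse = np.mean((pred - target) ** 2)
    return 20 * np.log10(max_val) - 10 * np.log10(mse)

scores = [psnr(reNormalize(output[k].detach().cpu(), color_mean, color_std),
               reNormalize(label[k], color_mean, color_std))
          for k in range(test_batch_size)]
print("mean PSNR over the last test batch: %.2f dB" % np.mean(scores))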

In Chapter 3 we built a generative model, a GAN, that converts grayscale images into color images. In the next chapter we will use the pix2pix model, which is more specialized for colorization, and compare the results.