4. RetinaNet

Open In Colab

3์žฅ์—์„œ๋Š” ์ œ๊ณต๋œ ๋ฐ์ดํ„ฐ์— augmentation์„ ๊ฐ€ํ•˜๋Š” ๋ฐฉ๋ฒ•๊ณผ ๋ฐ์ดํ„ฐ์…‹ ํด๋ž˜์Šค๋ฅผ ๋งŒ๋“œ๋Š” ๋ฐฉ๋ฒ•์„ ํ™•์ธํ–ˆ์Šต๋‹ˆ๋‹ค. ์ด๋ฒˆ ์žฅ์—์„œ๋Š” torchvision์—์„œ ์ œ๊ณตํ•˜๋Š” one-stage ๋ชจ๋ธ์ธ RetinaNet์„ ํ™œ์šฉํ•ด ์˜๋ฃŒ์šฉ ๋งˆ์Šคํฌ ๊ฒ€์ถœ ๋ชจ๋ธ์„ ๊ตฌ์ถ•ํ•ด๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค.

4.1์ ˆ๋ถ€ํ„ฐ 4.3์ ˆ๊นŒ์ง€๋Š” 2์žฅ๊ณผ 3์žฅ์—์„œ ํ™•์ธํ•œ ๋‚ด์šฉ์„ ๋ฐ”ํƒ•์œผ๋กœ ๋ฐ์ดํ„ฐ๋ฅผ ๋ถˆ๋Ÿฌ์˜ค๊ณ  ํ›ˆ๋ จ์šฉ, ์‹œํ—˜์šฉ ๋ฐ์ดํ„ฐ๋กœ ๋‚˜๋ˆˆ ํ›„ ๋ฐ์ดํ„ฐ์…‹ ํด๋ž˜์Šค๋ฅผ ์ •์˜ํ•˜๊ฒ ์Šต๋‹ˆ๋‹ค. 4.4์ ˆ์—์„œ๋Š” torchvision API๋ฅผ ํ™œ์šฉํ•˜์—ฌ ์‚ฌ์ „ ํ›ˆ๋ จ๋œ ๋ชจ๋ธ์„ ๋ถˆ๋Ÿฌ์˜ค๊ฒ ์Šต๋‹ˆ๋‹ค. 4.5์ ˆ์—์„œ๋Š” ์ „์ด ํ•™์Šต์„ ํ†ตํ•ด ๋ชจ๋ธ ํ•™์Šต์„ ์ง„ํ–‰ํ•œ ํ›„ 4.6์ ˆ์—์„œ ์˜ˆ์ธก๊ฐ’ ์‚ฐ์ถœ ๋ฐ ๋ชจ๋ธ ์„ฑ๋Šฅ์„ ํ™•์ธํ•ด๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค.

4.1 ๋ฐ์ดํ„ฐ ๋‹ค์šด๋กœ๋“œยถ

๋ชจ๋ธ๋ง ์‹ค์Šต์„ ์œ„ํ•ด 2.1์ ˆ์— ๋‚˜์˜จ ์ฝ”๋“œ๋ฅผ ํ™œ์šฉํ•˜์—ฌ ๋ฐ์ดํ„ฐ๋ฅผ ๋ถˆ๋Ÿฌ์˜ค๊ฒ ์Šต๋‹ˆ๋‹ค.

# Clone the Pseudo-Lab helper repository (it provides PL_data_loader.py and the
# utils_ObjectDetection.py module used later for mAP).
!git clone https://github.com/Pseudo-Lab/Tutorial-Book-Utils
# Download the Face Mask Detection archive with the helper script.
!python Tutorial-Book-Utils/PL_data_loader.py --data FaceMaskDetection
# Extract quietly; the archive is expected to provide the images/ and
# annotations/ folders read by the next cell.
!unzip -q Face\ Mask\ Detection.zip
Cloning into 'Tutorial-Book-Utils'...
remote: Enumerating objects: 12, done.
remote: Counting objects: 100% (12/12), done.
remote: Compressing objects: 100% (11/11), done.
remote: Total 12 (delta 1), reused 2 (delta 0), pack-reused 0
Unpacking objects: 100% (12/12), done.
Face Mask Detection.zip is done!

4.2 ๋ฐ์ดํ„ฐ ๋ถ„๋ฆฌยถ

3.3์ ˆ์—์„œ ํ™•์ธํ•œ ๋ฐ์ดํ„ฐ ๋ถ„๋ฆฌ ๋ฐฉ๋ฒ•์„ ํ™œ์šฉํ•˜์—ฌ ๋ฐ์ดํ„ฐ๋ฅผ ๋ถ„๋ฆฌํ•˜๊ฒ ์Šต๋‹ˆ๋‹ค.

import os
import random
import numpy as np
import shutil

print(len(os.listdir('annotations')))
print(len(os.listdir('images')))

# Create the hold-out folders. exist_ok makes the cell safe to re-run
# (a plain `mkdir` fails once the folders already exist).
os.makedirs('test_images', exist_ok=True)
os.makedirs('test_annotations', exist_ok=True)

# Sample 170 test images with a fixed seed. The population size comes from the
# folder itself instead of a hard-coded 853; for the 853-image dataset the
# sampled indices are identical to random.sample(range(853), 170).
random.seed(1234)
image_files = sorted(os.listdir('images'))
idx = random.sample(range(len(image_files)), 170)

# Move each sampled image together with the annotation that shares its file
# stem, instead of assuming two independently sorted listings line up.
for i in idx:
    img = image_files[i]
    annot = os.path.splitext(img)[0] + '.xml'
    shutil.move(os.path.join('images', img), os.path.join('test_images', img))
    shutil.move(os.path.join('annotations', annot), os.path.join('test_annotations', annot))

print(len(os.listdir('annotations')))
print(len(os.listdir('images')))
print(len(os.listdir('test_annotations')))
print(len(os.listdir('test_images')))
853
853
683
683
170
170

4.3 ๋ฐ์ดํ„ฐ์…‹ ํด๋ž˜์Šค ์ •์˜ยถ

ํŒŒ์ดํ† ์น˜ ๋ชจ๋ธ์„ ํ•™์Šต์‹œํ‚ค๊ธฐ ์œ„ํ•ด์„  ๋ฐ์ดํ„ฐ์…‹ ํด๋ž˜์Šค๋ฅผ ์ •์˜ํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค. torchvision์—์„œ ์ œ๊ณตํ•˜๋Š” ๊ฐ์ฒด ํƒ์ง€ ๋ชจ๋ธ์„ ํ•™์Šต์‹œํ‚ค๊ธฐ ์œ„ํ•œ ๋ฐ์ดํ„ฐ์…‹ ํด๋ž˜์Šค์˜ __getitem__ ๋ฉ”์„œ๋“œ๋Š” ์ด๋ฏธ์ง€ ํŒŒ์ผ๊ณผ ๋ฐ”์šด๋”ฉ ๋ฐ•์Šค ์ขŒํ‘œ๋ฅผ ๋ฐ˜ํ™˜ ํ•ฉ๋‹ˆ๋‹ค. ๋ฐ์ดํ„ฐ์…‹ ํด๋ž˜์Šค๋ฅผ 3์žฅ์—์„œ ํ™œ์šฉํ•œ ์ฝ”๋“œ๋ฅผ ์‘์šฉํ•ด ์•„๋ž˜์™€ ๊ฐ™์ด ์ •์˜ ํ•˜๊ฒ ์Šต๋‹ˆ๋‹ค.

import os
import glob
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import matplotlib.patches as patches
from bs4 import BeautifulSoup
from PIL import Image
import cv2
import numpy as np
import time
import torch
import torchvision
from torch.utils.data import Dataset
from torchvision import transforms
from matplotlib import pyplot as plt
import os

def generate_box(obj):
    """Return the bounding box of one annotation ``<object>`` element.

    The ``xmin``, ``ymin``, ``xmax`` and ``ymax`` child tags are parsed as
    floats and returned in that order as ``[xmin, ymin, xmax, ymax]``.
    """
    corner_tags = ('xmin', 'ymin', 'xmax', 'ymax')
    return [float(obj.find(tag).text) for tag in corner_tags]

def generate_label(obj):
    """Map an annotation ``<object>`` element's ``<name>`` text to a class id.

    ``"with_mask"`` -> 1, ``"mask_weared_incorrect"`` -> 2; any other name
    (the tutorial treats class 0 as "no mask") -> 0.
    """
    class_ids = {"with_mask": 1, "mask_weared_incorrect": 2}
    return class_ids.get(obj.find('name').text, 0)

def generate_target(file):
    """Parse one annotation XML file into a torchvision detection target.

    Args:
        file: path to the annotation file. Every ``<object>`` element in it
            contributes one box (via ``generate_box``) and one class id
            (via ``generate_label``).

    Returns:
        dict with
          - ``"boxes"``: float32 tensor of shape ``(N, 4)``, rows
            ``[xmin, ymin, xmax, ymax]``;
          - ``"labels"``: int64 tensor of shape ``(N,)``.
        ``N`` may be 0 when the file contains no objects.
    """
    with open(file, encoding="utf-8") as f:
        data = f.read()

    soup = BeautifulSoup(data, "html.parser")
    objects = soup.find_all("object")

    boxes = [generate_box(obj) for obj in objects]
    labels = [generate_label(obj) for obj in objects]

    # reshape(-1, 4) keeps the (N, 4) layout even when there are no objects:
    # torch.as_tensor([]) alone has shape (0,), which torchvision detection
    # models reject as a target "boxes" tensor.
    target = {
        "boxes": torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4),
        "labels": torch.as_tensor(labels, dtype=torch.int64),
    }
    return target

def plot_image_from_output(img, annotation):
    """Prepare an image tensor and its boxes for display with matplotlib.

    Args:
        img: image tensor, assumed channel-first (C, H, W) as produced by
            ``ToTensor``. It is moved to the CPU and permuted to (H, W, C)
            so that ``imshow`` can draw it.
        annotation: dict with ``"boxes"`` (rows ``xmin, ymin, xmax, ymax``)
            and a parallel ``"labels"`` sequence.

    Returns:
        ``(img, rects)``: the permuted image, plus one unfilled
        ``patches.Rectangle`` per box. The outline colour is red for label 0,
        green for label 1 and orange for any other label.
    """
    picture = img.cpu().permute(1, 2, 0)
    box_labels = annotation['labels']

    rects = []
    for i, box in enumerate(annotation["boxes"]):
        xmin, ymin, xmax, ymax = box
        label = box_labels[i]
        colour = 'r' if label == 0 else ('g' if label == 1 else 'orange')
        width = xmax - xmin
        height = ymax - ymin
        rects.append(
            patches.Rectangle((xmin, ymin), width, height,
                              linewidth=1, edgecolor=colour, facecolor='none')
        )

    return picture, rects

class MaskDataset(Dataset):
    """Face-mask detection dataset pairing each image with its XML annotation.

    ``__getitem__(idx)`` returns ``(img, target)``:
      - ``img`` is the image converted with ``torchvision.transforms.ToTensor``;
      - ``target`` is the dict built by ``generate_target`` (``"boxes"``,
        ``"labels"``).

    Args:
        path: directory of image files. Every entry in it is treated as an
            image, and entries are served in sorted filename order.
        transform: optional callable ``transform(image_array, boxes_array)``
            that returns ``(image, boxes)``.
        annotation_path: directory of ``.xml`` annotations. When omitted,
            ``"test_annotations/"`` is used if ``"test"`` appears in ``path``,
            and ``"annotations/"`` otherwise (the original behaviour).
    """

    def __init__(self, path, transform=None, annotation_path=None):
        self.path = path
        self.imgs = sorted(os.listdir(self.path))
        self.transform = transform
        if annotation_path is None:
            # Legacy heuristic kept for backward compatibility. Pass
            # annotation_path explicitly to avoid matching an unrelated
            # "test" substring somewhere in the image directory path.
            annotation_path = "test_annotations/" if 'test' in path else "annotations/"
        self.annotation_path = annotation_path

    def __len__(self):
        return len(self.imgs)

    def _label_path(self, file_image):
        """Return the annotation path for ``file_image`` (same stem, ``.xml``)."""
        # splitext handles extensions of any length. Slicing off the last three
        # characters only worked for 3-character extensions such as ``.png``.
        stem, _ = os.path.splitext(file_image)
        return os.path.join(self.annotation_path, stem + '.xml')

    def __getitem__(self, idx):
        file_image = self.imgs[idx]
        img_path = os.path.join(self.path, file_image)

        # The context manager closes the underlying file once the RGB copy exists.
        with Image.open(img_path) as image:
            img = image.convert("RGB")
        target = generate_target(self._label_path(file_image))

        if self.transform:
            img, transform_target = self.transform(np.array(img), np.array(target['boxes']))
            # Keep float32 boxes, matching the dtype produced by generate_target.
            target['boxes'] = torch.as_tensor(transform_target, dtype=torch.float32)

        # Convert the (possibly transformed) image to a tensor.
        img = torchvision.transforms.ToTensor()(img)

        return img, target

def collate_fn(batch):
    """Transpose a list of ``(img, target)`` samples into ``(imgs, targets)``.

    Detection samples may differ in image size and box count, so they are
    regrouped into tuples rather than stacked into a single tensor.
    """
    images_and_targets = zip(*batch)
    return tuple(images_and_targets)

dataset = MaskDataset('images/')
test_dataset = MaskDataset('test_images/')

# Shuffle the training data so every epoch sees the batches in a new order.
# Keep the test loader ordered: the prediction/visualisation cells re-iterate
# test_data_loader and index stored predictions by batch position.
data_loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
test_data_loader = torch.utils.data.DataLoader(test_dataset, batch_size=2, shuffle=False, collate_fn=collate_fn)

์ตœ์ข…์ ์œผ๋กœ ํ›ˆ๋ จ์šฉ ๋ฐ์ดํ„ฐ์™€ ์‹œํ—˜์šฉ ๋ฐ์ดํ„ฐ๋ฅผ batch ๋‹จ์œ„๋กœ ๋ถˆ๋Ÿฌ์˜ฌ ์ˆ˜ ์žˆ๊ฒŒ torch.utils.data.DataLoader ํ•จ์ˆ˜๋ฅผ ํ™œ์šฉํ•ด data_loader์™€ test_data_loader๋ฅผ ๊ฐ๊ฐ ์ •์˜ํ•ฉ๋‹ˆ๋‹ค.

4.4 ๋ชจ๋ธ ๋ถˆ๋Ÿฌ์˜ค๊ธฐยถ

torchvision์—์„œ๋Š” ๊ฐ์ข… ์ปดํ“จํ„ฐ ๋น„์ „ ๋ฌธ์ œ๋ฅผ ํ•ด๊ฒฐํ•˜๊ธฐ ์œ„ํ•œ ๋”ฅ๋Ÿฌ๋‹ ๋ชจ๋ธ์„ ์‰ฝ๊ฒŒ ๋ถˆ๋Ÿฌ์˜ฌ ์ˆ˜ ์žˆ๋Š” API๋ฅผ ์ œ๊ณตํ•ฉ๋‹ˆ๋‹ค. torchvision.models ๋ชจ๋“ˆ์„ ํ™œ์šฉํ•˜์—ฌ RetinaNet ๋ชจ๋ธ์„ ๋ถˆ๋Ÿฌ์˜ค๋„๋ก ํ•˜๊ฒ ์Šต๋‹ˆ๋‹ค. RetinaNet์€ torchvision 0.8.0 ์ด์ƒ์—์„œ ์ œ๊ณต๋˜๋ฏ€๋กœ, ์•„๋ž˜ ์ฝ”๋“œ๋ฅผ ํ™œ์šฉํ•˜์—ฌ torchvision ๋ฒ„์ „์„ ๋งž์ถฐ์ค๋‹ˆ๋‹ค.

# Install torch 1.7.0 / torchvision 0.8.1 wheels built for CUDA 10.1;
# RetinaNet requires torchvision >= 0.8.0.
!pip install torch==1.7.0+cu101 torchvision==0.8.1+cu101 torchaudio==0.7.0 -f https://download.pytorch.org/whl/torch_stable.html
Looking in links: https://download.pytorch.org/whl/torch_stable.html
Requirement already satisfied: torch==1.7.0+cu101 in /usr/local/lib/python3.6/dist-packages (1.7.0+cu101)
Requirement already satisfied: torchvision==0.8.1+cu101 in /usr/local/lib/python3.6/dist-packages (0.8.1+cu101)
Collecting torchaudio==0.7.0
?25l  Downloading https://files.pythonhosted.org/packages/3f/23/6b54106b3de029d3f10cf8debc302491c17630357449c900d6209665b302/torchaudio-0.7.0-cp36-cp36m-manylinux1_x86_64.whl (7.6MB)
     |โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 7.6MB 11.1MB/s 
?25hRequirement already satisfied: dataclasses in /usr/local/lib/python3.6/dist-packages (from torch==1.7.0+cu101) (0.8)
Requirement already satisfied: typing-extensions in /usr/local/lib/python3.6/dist-packages (from torch==1.7.0+cu101) (3.7.4.3)
Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from torch==1.7.0+cu101) (1.18.5)
Requirement already satisfied: future in /usr/local/lib/python3.6/dist-packages (from torch==1.7.0+cu101) (0.16.0)
Requirement already satisfied: pillow>=4.1.1 in /usr/local/lib/python3.6/dist-packages (from torchvision==0.8.1+cu101) (7.0.0)
Installing collected packages: torchaudio
Successfully installed torchaudio-0.7.0
import torchvision
import torch
# Display the installed torchvision version (the cell echoes its last expression).
torchvision.__version__
'0.8.1+cu101'

torchvision.__version__ 명령어를 통해 현재 cuda 10.1 버전에서 작동하는 torchvision 0.8.1 버전이 설치됐음을 확인할 수 있습니다. 다음으로는 아래 코드를 실행하여 RetinaNet 모델을 불러옵니다. Face Mask Detection 데이터셋에 3개의 클래스가 존재하므로 num_classes 매개변수를 3으로 정의합니다. 또한 전이 학습을 할 것이기 때문에 backbone 구조는 사전 학습된 가중치를, 그 외 가중치는 초기화 상태로 가져옵니다(pretrained=False, pretrained_backbone=True). 이때 불러오는 backbone(ResNet-50)은 이미지 분류 데이터셋으로 유명한 ImageNet 데이터셋에 사전 학습됐습니다.

# Three output classes (generate_label yields ids 0, 1 and 2).
# pretrained=False: the detection heads start from fresh initialisation.
# pretrained_backbone=True: only the ResNet-50 backbone loads pretrained
# weights; per torchvision's documentation these are ImageNet weights.
retina = torchvision.models.detection.retinanet_resnet50_fpn(num_classes = 3, pretrained=False, pretrained_backbone = True)

4.5 전이 학습

๋ชจ๋ธ์„ ๋ถˆ๋Ÿฌ์™”์œผ๋ฉด ์•„๋ž˜ ์ฝ”๋“œ๋ฅผ ํ™œ์šฉํ•˜์—ฌ ์ „์ด ํ•™์Šต์„ ์ง„ํ–‰ํ•ฉ๋‹ˆ๋‹ค.

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

num_epochs = 30
retina.to(device)

# Optimise only the parameters that require gradients.
params = [p for p in retina.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005,
                            momentum=0.9, weight_decay=0.0005)

len_dataloader = len(data_loader)

# Each epoch took roughly 4 minutes (~250 s) in the logged Colab run.
for epoch in range(num_epochs):
    start = time.time()
    retina.train()

    epoch_loss = 0.0
    for images, targets in data_loader:
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        # In training mode the detector returns a dict of named losses.
        loss_dict = retina(images, targets)
        losses = sum(loss for loss in loss_dict.values())

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        # .item() records a plain float. Accumulating the tensor itself would
        # chain autograd history across every batch of the epoch.
        epoch_loss += losses.item()
    print(epoch_loss, f'time: {time.time() - start}')
/usr/local/lib/python3.6/dist-packages/torch/nn/_reduction.py:44: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.
  warnings.warn(warning.format(ret))
tensor(285.9670, device='cuda:0', grad_fn=<AddBackward0>) time: 242.22558188438416
tensor(268.1001, device='cuda:0', grad_fn=<AddBackward0>) time: 251.5482075214386
tensor(248.4554, device='cuda:0', grad_fn=<AddBackward0>) time: 248.92862486839294
tensor(233.0612, device='cuda:0', grad_fn=<AddBackward0>) time: 249.69438576698303
tensor(234.2285, device='cuda:0', grad_fn=<AddBackward0>) time: 247.88670659065247
tensor(202.4744, device='cuda:0', grad_fn=<AddBackward0>) time: 249.68517541885376
tensor(172.9739, device='cuda:0', grad_fn=<AddBackward0>) time: 250.47061586380005
tensor(125.8968, device='cuda:0', grad_fn=<AddBackward0>) time: 251.4771168231964
tensor(102.0443, device='cuda:0', grad_fn=<AddBackward0>) time: 251.20848298072815
tensor(88.1749, device='cuda:0', grad_fn=<AddBackward0>) time: 251.144877910614
tensor(78.1594, device='cuda:0', grad_fn=<AddBackward0>) time: 251.8066761493683
tensor(73.6921, device='cuda:0', grad_fn=<AddBackward0>) time: 251.669575214386
tensor(69.6965, device='cuda:0', grad_fn=<AddBackward0>) time: 251.8230264186859
tensor(63.9101, device='cuda:0', grad_fn=<AddBackward0>) time: 252.08272123336792
tensor(56.2955, device='cuda:0', grad_fn=<AddBackward0>) time: 252.18470931053162
tensor(56.2638, device='cuda:0', grad_fn=<AddBackward0>) time: 252.03237462043762
tensor(50.2047, device='cuda:0', grad_fn=<AddBackward0>) time: 252.09569120407104
tensor(45.9254, device='cuda:0', grad_fn=<AddBackward0>) time: 253.205641746521
tensor(44.4599, device='cuda:0', grad_fn=<AddBackward0>) time: 253.05651235580444
tensor(43.9277, device='cuda:0', grad_fn=<AddBackward0>) time: 253.1837260723114
tensor(40.4117, device='cuda:0', grad_fn=<AddBackward0>) time: 253.18618297576904
tensor(39.0882, device='cuda:0', grad_fn=<AddBackward0>) time: 253.36814761161804
tensor(35.3732, device='cuda:0', grad_fn=<AddBackward0>) time: 253.41503262519836
tensor(34.0460, device='cuda:0', grad_fn=<AddBackward0>) time: 252.93738174438477
tensor(35.8844, device='cuda:0', grad_fn=<AddBackward0>) time: 253.25822925567627
tensor(33.1177, device='cuda:0', grad_fn=<AddBackward0>) time: 253.25469851493835
tensor(28.4753, device='cuda:0', grad_fn=<AddBackward0>) time: 253.2648823261261
tensor(30.3831, device='cuda:0', grad_fn=<AddBackward0>) time: 253.4244725704193
tensor(28.0954, device='cuda:0', grad_fn=<AddBackward0>) time: 253.57142424583435
tensor(28.5899, device='cuda:0', grad_fn=<AddBackward0>) time: 253.16517424583435

๋ชจ๋ธ ์žฌ์‚ฌ์šฉ์„ ์œ„ํ•ด ์•„๋ž˜ ์ฝ”๋“œ๋ฅผ ์‹คํ–‰ํ•˜์—ฌ ํ•™์Šต๋œ ๊ฐ€์ค‘์น˜๋ฅผ ์ €์žฅํ•ด์ค๋‹ˆ๋‹ค. torch.save ํ•จ์ˆ˜๋ฅผ ํ™œ์šฉํ•ด ์ง€์ •ํ•œ ์œ„์น˜์— ํ•™์Šต๋œ ๊ฐ€์ค‘์น˜๋ฅผ ์ €์žฅํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

# Persist only the learned weights (the state_dict), named after the epoch count.
torch.save(retina.state_dict(), f'retina_{num_epochs}.pt')
# map_location='cpu' lets the checkpoint load even on a machine without a GPU.
# load_state_dict then copies the values into the model's existing parameters,
# wherever the model currently lives.
retina.load_state_dict(torch.load(f'retina_{num_epochs}.pt', map_location='cpu'))
<All keys matched successfully>

ํ•™์Šต๋œ ๊ฐ€์ค‘์น˜๋ฅผ ๋ถˆ๋Ÿฌ์˜ฌ ๋•Œ๋Š” load_state_dict๊ณผ torch.loadํ•จ์ˆ˜๋ฅผ ์‚ฌ์šฉํ•˜๋ฉด ๋ฉ๋‹ˆ๋‹ค. ๋งŒ์•ฝ retina ๋ณ€์ˆ˜๋ฅผ ์ƒˆ๋กญ๊ฒŒ ์ง€์ •ํ–ˆ์„ ๊ฒฝ์šฐ, ํ•ด๋‹น ๋ชจ๋ธ์„ GPU ๋ฉ”๋ชจ๋ฆฌ์— ์˜ฌ๋ ค์ฃผ์–ด์•ผ GPU ์—ฐ์‚ฐ์ด ๊ฐ€๋Šฅํ•ฉ๋‹ˆ๋‹ค.

# Pick the GPU when available and move the (possibly re-created) model onto it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
retina.to(device)

4.6 예측

ํ›ˆ๋ จ์ด ๋งˆ๋ฌด๋ฆฌ ๋˜์—ˆ์œผ๋ฉด, ์˜ˆ์ธก ๊ฒฐ๊ณผ๋ฅผ ํ™•์ธํ•˜๋„๋ก ํ•˜๊ฒ ์Šต๋‹ˆ๋‹ค. test_data_loader์—์„œ ๋ฐ์ดํ„ฐ๋ฅผ ๋ถˆ๋Ÿฌ์™€ ๋ชจ๋ธ์— ๋„ฃ์–ด ํ•™์Šต ํ›„, ์˜ˆ์ธก๋œ ๊ฒฐ๊ณผ์™€ ์‹ค์ œ ๊ฐ’์„ ๊ฐ๊ฐ ์‹œ๊ฐํ™” ํ•ด๋ณด๋„๋ก ํ•˜๊ฒ ์Šต๋‹ˆ๋‹ค. ์šฐ์„  ์˜ˆ์ธก์— ํ•„์š”ํ•œ ํ•จ์ˆ˜๋ฅผ ์ •์˜ํ•˜๊ฒ ์Šต๋‹ˆ๋‹ค.

def make_prediction(model, img, threshold):
    """Run ``model`` in eval mode on ``img`` and keep only confident detections.

    Args:
        model: detection model called as ``model(img)``. It must return one
            dict per image with ``"boxes"``, ``"labels"`` and ``"scores"``
            tensors.
        img: list of image tensors, passed straight to ``model``.
        threshold: detections with ``score > threshold`` (strictly greater)
            are kept.

    Returns:
        The model's list of prediction dicts, each filtered in place.

    Note: this switches ``model`` to eval mode but does not disable autograd;
    wrap the call in ``torch.no_grad()`` for inference.
    """
    model.eval()
    preds = model(img)
    for pred in preds:
        # A boolean mask keeps the original order. It also avoids a Python loop
        # with a device sync per score when the model runs on the GPU.
        keep = pred['scores'] > threshold
        pred['boxes'] = pred['boxes'][keep]
        pred['labels'] = pred['labels'][keep]
        pred['scores'] = pred['scores'][keep]

    return preds

make_prediction ํ•จ์ˆ˜์—๋Š” ํ•™์Šต๋œ ๋”ฅ๋Ÿฌ๋‹ ๋ชจ๋ธ์„ ํ™œ์šฉํ•ด ์˜ˆ์ธกํ•˜๋Š” ์•Œ๊ณ ๋ฆฌ์ฆ˜์ด ์ €์žฅ๋ผ ์žˆ์Šต๋‹ˆ๋‹ค. threshold ํŒŒ๋ผ๋ฏธํ„ฐ๋ฅผ ์กฐ์ •ํ•ด ์‹ ๋ขฐ๋„๊ฐ€ ์ผ์ • ์ˆ˜์ค€ ์ด์ƒ์˜ ๋ฐ”์šด๋”ฉ ๋ฐ•์Šค๋งŒ ์„ ํƒํ•ฉ๋‹ˆ๋‹ค. ๋ณดํ†ต 0.5 ์ด์ƒ์ธ ๊ฐ’์„ ์ตœ์ข… ์„ ํƒํ•ฉ๋‹ˆ๋‹ค. ๋‹ค์Œ์œผ๋กœ๋Š” for๋ฌธ์„ ํ™œ์šฉํ•ด test_data_loader์— ์žˆ๋Š” ๋ชจ๋“  ๋ฐ์ดํ„ฐ์— ๋Œ€ํ•ด ์˜ˆ์ธก์„ ์‹ค์‹œํ•˜๊ฒ ์Šต๋‹ˆ๋‹ค.

from tqdm import tqdm

labels = []          # ground-truth class id of every test object, flattened
preds_adj_all = []   # per batch: list of filtered prediction dicts (on CPU)
annot_all = []       # per batch: tuple of ground-truth target dicts

for im, annot in tqdm(test_data_loader, position=0, leave=True):
    im = [img.to(device) for img in im]

    for t in annot:
        # .tolist() stores plain ints rather than 0-d tensors.
        labels += t['labels'].tolist()

    with torch.no_grad():
        preds_adj = make_prediction(retina, im, 0.5)
        # Bring predictions back to the CPU, where the annotations live.
        preds_adj = [{k: v.cpu() for k, v in t.items()} for t in preds_adj]
        preds_adj_all.append(preds_adj)
        annot_all.append(annot)
100%|██████████| 85/85 [00:24<00:00,  3.47it/s]

tqdm 함수를 활용해 진행 상황을 표시하고 있습니다. 예측된 모든 값은 preds_adj_all 변수에 저장됐습니다. 다음으로는 실제 바운딩 박스와 예측한 바운딩 박스에 대한 시각화를 진행해보겠습니다.

nrows = 8
ncols = 2
fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols*4, nrows*4))

# One row per test image: ground truth in column 0, prediction in column 1.
# A running row counter fills exactly `nrows` rows for any batch size. The
# previous position arithmetic assumed two images per batch and four batches.
row = 0
for batch_i, (im, annot) in enumerate(test_data_loader):
    for sample_i in range(len(im)):
        if row == nrows:
            break
        sources = (annot[sample_i], preds_adj_all[batch_i][sample_i])
        for col, target in enumerate(sources):
            img, rects = plot_image_from_output(im[sample_i], target)
            axes[row, col].imshow(img)
            for rect in rects:
                axes[row, col].add_patch(rect)
        row += 1
    if row == nrows:
        break

# Remove x/y ticks from every subplot.
for ax in axes.flat:
    ax.set_xticks([])
    ax.set_yticks([])

colnames = ['True', 'Pred']

for ax, colname in zip(axes[0], colnames):
    ax.set_title(colname)

plt.tight_layout()
plt.show()
../../_images/Ch4-RetinaNet_35_0.png

for๋ฌธ์„ ํ™œ์šฉํ•ด 4๊ฐœ์˜ batch, ์ด 8๊ฐœ์˜ ์ด๋ฏธ์ง€์— ๋Œ€ํ•œ ์‹ค์ œ ๊ฐ’๊ณผ ์˜ˆ์ธก ๊ฐ’์„ ์‹œ๊ฐํ•ด ๋ณด์•˜์Šต๋‹ˆ๋‹ค. ์™ผ์ชฝ ์—ด์ด ์‹ค์ œ ๋ฐ”์šด๋”ฉ ๋ฐ•์Šค์˜ ๋ผ๋ฒจ๊ณผ ์œ„์น˜์ด๋ฉฐ ์˜ค๋ฅธ์ชฝ ์—ด์ด ๋ชจ๋ธ์˜ ์˜ˆ์ธก ๊ฐ’์ž…๋‹ˆ๋‹ค. ๋งˆ์Šคํฌ ์ฐฉ์šฉ์ž(์ดˆ๋ก์ƒ‰)๋Š” ์ž˜ ํƒ์ง€ํ•˜๊ณ  ์žˆ๋Š” ๊ฒƒ์„ ๊ด€์ธกํ•˜๊ณ  ์žˆ์œผ๋ฉฐ, ๋งˆ์Šคํฌ ๋ฏธ์ฐฉ์šฉ์ž(๋นจ๊ฐ„์ƒ‰)์— ๋Œ€ํ•ด์„œ๋Š” ๊ฐ€๋”์”ฉ ๋งˆ์Šคํฌ๋ฅผ ์˜ฌ๋ฐ”๋ฅด์ง€ ์•Š๊ฒŒ ์ฐฉ์šฉํ•œ ๊ฒƒ(์ฃผํ™ฉ์ƒ‰)์œผ๋กœ ํƒ์ง€ํ•œ ๊ฒƒ์„ ๋ณผ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์ „๋ฐ˜์ ์ธ ๋ชจ๋ธ ์„ฑ๋Šฅ์„ ํ‰๊ฐ€ํ•˜๊ธฐ ์œ„ํ•ด mean Average Precision (mAP)๋ฅผ ์‚ฐ์ถœํ•ด๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค. mAP๋Š” ๊ฐ์ฒด ํƒ์ง€ ๋ชจ๋ธ์„ ํ‰๊ฐ€ํ•  ๋•Œ ์‚ฌ์šฉํ•˜๋Š” ์ง€ํ‘œ์ž…๋‹ˆ๋‹ค.

๋ฐ์ดํ„ฐ ๋‹ค์šด๋กœ๋“œ์‹œ ๋ถˆ๋Ÿฌ์™”๋˜ Tutorial-Book-Utils ํด๋” ๋‚ด์—๋Š” utils_ObjectDetection.py ํŒŒ์ผ์ด ์žˆ์Šต๋‹ˆ๋‹ค. ํ•ด๋‹น ๋ชจ๋“ˆ ๋‚ด์— ์žˆ๋Š” ํ•จ์ˆ˜๋ฅผ ํ™œ์šฉํ•ด mAP๋ฅผ ์‚ฐ์ถœํ•ด๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค. ์šฐ์„  utils_ObjectDetection.py ๋ชจ๋“ˆ์„ ๋ถˆ๋Ÿฌ์˜ต๋‹ˆ๋‹ค.

import sys
# Make the helper module importable without changing the working directory.
# `%cd` would move the whole kernel, so relative paths used elsewhere in the
# notebook (images/, test_images/, the saved retina_*.pt) would stop
# resolving if those cells were re-run.
sys.path.insert(0, 'Tutorial-Book-Utils')
import utils_ObjectDetection as utils
/content/Tutorial-Book-Utils
# Gather matching statistics for every test batch at an IoU threshold of 0.5.
sample_metrics = []
for batch_i, batch_preds in enumerate(preds_adj_all):
    batch_stats = utils.get_batch_statistics(batch_preds, annot_all[batch_i], iou_threshold=0.5)
    sample_metrics.extend(batch_stats)

batch ๋ณ„ mAP๋ฅผ ์‚ฐ์ถœํ•˜๋Š”๋ฐ ํ•„์š”ํ•œ ์ •๋ณด๋ฅผ sample_metrics์— ์ €์žฅ ํ›„ ap_per_classํ•จ์ˆ˜๋ฅผ ํ™œ์šฉํ•ด mAP๋ฅผ ์‚ฐ์ถœํ•ฉ๋‹ˆ๋‹ค.

# Concatenate the per-batch statistics into flat tensors covering the whole test set.
true_positives, pred_scores, pred_labels = (
    torch.cat(column, 0) for column in zip(*sample_metrics)
)
# ap_per_class also receives the ground-truth class of every test object (`labels`).
precision, recall, AP, f1, ap_class = utils.ap_per_class(
    true_positives, pred_scores, pred_labels, torch.tensor(labels)
)
mAP = torch.mean(AP)
print(f'mAP : {mAP}')
print(f'AP : {AP}')
mAP : 0.5824690281035101
AP : tensor([0.7684, 0.9188, 0.0603], dtype=torch.float64)

๊ฒฐ๊ณผ๋ฅผ ํ•ด์„ํ•˜๋ฉด 0๋ฒˆ ํด๋ž˜์Šค์ธ ๋งˆ์Šคํฌ๋ฅผ ๋ฏธ์ฐฉ์šฉํ•œ ๊ฐ์ฒด์— ๋Œ€ํ•ด์„œ๋Š” 0.7684 AP๋ฅผ ๋ณด์ด๋ฉฐ 1๋ฒˆ ํด๋ž˜์Šค์ธ ๋งˆ์Šคํฌ ์ฐฉ์šฉ ๊ฐ์ฒด์— ๋Œ€ํ•ด์„œ๋Š” 0.9188 AP๋ฅผ ๋ณด์ด๊ณ , 2๋ฒˆ ํด๋ž˜์Šค์ธ ๋งˆ์Šคํฌ๋ฅผ ์˜ฌ๋ฐ”๋ฅด๊ฒŒ ์ฐฉ์šฉํ•˜์ง€ ์•Š์€ ๊ฐ์ฒด์— ๋Œ€ํ•ด์„œ๋Š” 0.06 AP๋ฅผ ๋ณด์ž…๋‹ˆ๋‹ค.

์ง€๊ธˆ๊นŒ์ง€ RetinaNet์— ๋Œ€ํ•œ ์ „์ด ํ•™์Šต์„ ์‹ค์‹œํ•ด ์˜๋ฃŒ์šฉ ๋งˆ์Šคํฌ ํƒ์ง€ ๋ชจ๋ธ์„ ๋งŒ๋“ค์–ด ๋ณด์•˜์Šต๋‹ˆ๋‹ค. ๋‹ค์Œ ์žฅ์—์„œ๋Š” Two-Stage Detector์ธ Faster R-CNN์„ ํ™œ์šฉํ•ด ํƒ์ง€ ์„ฑ๋Šฅ์„ ๋†’์—ฌ๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค.