4. RetinaNet
In chapter 3, we looked at how to apply augmentation to the provided data and how to build a dataset class. In this chapter, we will build a medical mask detection model using RetinaNet, a one-stage model provided by torchvision.
In sections 4.1 through 4.3, building on chapters 2 and 3, we will load the data, split it into training and test sets, and define the dataset class. In section 4.4, we will load a pre-trained model using the torchvision API. In section 4.5, we will train the model via transfer learning, and in section 4.6 we will produce predictions and evaluate the model's performance.
4.1 Data Download
For the modeling exercise, we will fetch the data using the code from section 2.1.
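That code cell is not preserved on this page. A sketch consistent with the clone log below and with section 2.1 would be the following; the loader script name PL_data_loader.py and its --data flag are assumptions based on the rest of this tutorial series, while the repository and archive names are taken from the log:

!git clone https://github.com/Pseudo-Lab/Tutorial-Book-Utils   # repo name from the clone log below
!python Tutorial-Book-Utils/PL_data_loader.py --data FaceMaskDetection   # assumed loader script
!unzip -q 'Face Mask Detection.zip'   # archive name from the log below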
Cloning into 'Tutorial-Book-Utils'...
remote: Enumerating objects: 12, done.
remote: Counting objects: 100% (12/12), done.
remote: Compressing objects: 100% (11/11), done.
remote: Total 12 (delta 1), reused 2 (delta 0), pack-reused 0
Unpacking objects: 100% (12/12), done.
Face Mask Detection.zip is done!
4.2 Data Split
We will split the data using the method covered in section 3.3.
import os
import random
import shutil
import numpy as np

print(len(os.listdir('annotations')))
print(len(os.listdir('images')))

!mkdir test_images
!mkdir test_annotations

random.seed(1234)
idx = random.sample(range(853), 170)  # hold out 170 of the 853 samples for testing

# move the sampled images and their matching annotations into the test folders
for img in np.array(sorted(os.listdir('images')))[idx]:
    shutil.move('images/'+img, 'test_images/'+img)

for annot in np.array(sorted(os.listdir('annotations')))[idx]:
    shutil.move('annotations/'+annot, 'test_annotations/'+annot)

print(len(os.listdir('annotations')))
print(len(os.listdir('images')))
print(len(os.listdir('test_annotations')))
print(len(os.listdir('test_images')))
4.3 Dataset Class Definition
To train a PyTorch model, we must define a dataset class. For the object detection models provided by torchvision, the dataset class's __getitem__ method returns the image and a target dictionary holding the bounding box coordinates and labels. We will adapt the code used in chapter 3 to define the dataset class as follows.
import os
import glob
import time

import numpy as np
import cv2
import torch
import torchvision
from torch.utils.data import Dataset
from torchvision import transforms
from bs4 import BeautifulSoup
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import matplotlib.patches as patches
def generate_box(obj):
    # extract bounding box coordinates from one <object> node of the annotation XML
    xmin = float(obj.find('xmin').text)
    ymin = float(obj.find('ymin').text)
    xmax = float(obj.find('xmax').text)
    ymax = float(obj.find('ymax').text)
    return [xmin, ymin, xmax, ymax]

def generate_label(obj):
    # map class names to integer labels:
    # 0 = without mask, 1 = with mask, 2 = mask worn incorrectly
    if obj.find('name').text == "with_mask":
        return 1
    elif obj.find('name').text == "mask_weared_incorrect":
        return 2
    return 0

def generate_target(file):
    # parse an annotation XML file into the target dict expected by torchvision detection models
    with open(file) as f:
        data = f.read()
        soup = BeautifulSoup(data, "html.parser")
        objects = soup.find_all("object")

        boxes = []
        labels = []
        for i in objects:
            boxes.append(generate_box(i))
            labels.append(generate_label(i))

        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        labels = torch.as_tensor(labels, dtype=torch.int64)

        target = {}
        target["boxes"] = boxes
        target["labels"] = labels

        return target
def plot_image_from_output(img, annotation):
    # convert a CxHxW tensor back to an HxWxC image and build one rectangle patch per box
    img = img.cpu().permute(1,2,0)

    rects = []
    for idx in range(len(annotation["boxes"])):
        xmin, ymin, xmax, ymax = annotation["boxes"][idx]

        if annotation['labels'][idx] == 0:    # without mask -> red
            rect = patches.Rectangle((xmin,ymin),(xmax-xmin),(ymax-ymin),linewidth=1,edgecolor='r',facecolor='none')
        elif annotation['labels'][idx] == 1:  # with mask -> green
            rect = patches.Rectangle((xmin,ymin),(xmax-xmin),(ymax-ymin),linewidth=1,edgecolor='g',facecolor='none')
        else:                                 # mask worn incorrectly -> orange
            rect = patches.Rectangle((xmin,ymin),(xmax-xmin),(ymax-ymin),linewidth=1,edgecolor='orange',facecolor='none')

        rects.append(rect)

    return img, rects
class MaskDataset(Dataset):
    def __init__(self, path, transform=None):
        self.path = path
        self.imgs = list(sorted(os.listdir(self.path)))
        self.transform = transform

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        file_image = self.imgs[idx]
        file_label = self.imgs[idx][:-3] + 'xml'  # image and XML annotation share the same base name
        img_path = os.path.join(self.path, file_image)

        if 'test' in self.path:
            label_path = os.path.join("test_annotations/", file_label)
        else:
            label_path = os.path.join("annotations/", file_label)

        img = Image.open(img_path).convert("RGB")
        target = generate_target(label_path)

        to_tensor = torchvision.transforms.ToTensor()

        if self.transform:
            img, transform_target = self.transform(np.array(img), np.array(target['boxes']))
            target['boxes'] = torch.as_tensor(transform_target)

        # convert to tensor
        img = to_tensor(img)

        return img, target

def collate_fn(batch):
    # images carry different numbers of boxes, so return tuples instead of stacked tensors
    return tuple(zip(*batch))

dataset = MaskDataset('images/')
test_dataset = MaskDataset('test_images/')

data_loader = torch.utils.data.DataLoader(dataset, batch_size=4, collate_fn=collate_fn)
test_data_loader = torch.utils.data.DataLoader(test_dataset, batch_size=2, collate_fn=collate_fn)
Finally, so that the training and test data can be loaded batch by batch, we wrap each dataset in torch.utils.data.DataLoader to define data_loader and test_data_loader. Because each image can contain a different number of boxes, collate_fn returns each batch as tuples of images and targets rather than stacked tensors.
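To see what collate_fn produces, a quick inspection sketch (the printed shapes will vary with your images) can pull a single batch; imgs is a tuple of 4 image tensors and targets a tuple of 4 annotation dicts:

imgs, targets = next(iter(data_loader))
print(len(imgs), imgs[0].shape)      # 4 images, each a 3xHxW float tensor
print(targets[0]['boxes'].shape)     # (number of objects in the first image, 4)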
4.4 Loading the Model
torchvision provides an API for easily loading deep learning models for a variety of computer vision problems. We will use the torchvision.models module to load a RetinaNet model. RetinaNet is available from torchvision 0.8.0 onward, so we first align the torchvision version using the code below.
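The install cell itself is not preserved here; reconstructed from the log that follows (versions and wheel index copied from that log), it would look like:

!pip install torch==1.7.0+cu101 torchvision==0.8.1+cu101 torchaudio==0.7.0 -f https://download.pytorch.org/whl/torch_stable.html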
Looking in links: https://download.pytorch.org/whl/torch_stable.html
Requirement already satisfied: torch==1.7.0+cu101 in /usr/local/lib/python3.6/dist-packages (1.7.0+cu101)
Requirement already satisfied: torchvision==0.8.1+cu101 in /usr/local/lib/python3.6/dist-packages (0.8.1+cu101)
Collecting torchaudio==0.7.0
  Downloading https://files.pythonhosted.org/packages/3f/23/6b54106b3de029d3f10cf8debc302491c17630357449c900d6209665b302/torchaudio-0.7.0-cp36-cp36m-manylinux1_x86_64.whl (7.6MB)
     |████████████████████████████████| 7.6MB 11.1MB/s
Requirement already satisfied: dataclasses in /usr/local/lib/python3.6/dist-packages (from torch==1.7.0+cu101) (0.8)
Requirement already satisfied: typing-extensions in /usr/local/lib/python3.6/dist-packages (from torch==1.7.0+cu101) (3.7.4.3)
Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from torch==1.7.0+cu101) (1.18.5)
Requirement already satisfied: future in /usr/local/lib/python3.6/dist-packages (from torch==1.7.0+cu101) (0.16.0)
Requirement already satisfied: pillow>=4.1.1 in /usr/local/lib/python3.6/dist-packages (from torchvision==0.8.1+cu101) (7.0.0)
Installing collected packages: torchaudio
Successfully installed torchaudio-0.7.0
Checking torchvision.__version__ confirms that torchvision 0.8.1, built against CUDA 10.1, is now installed. Next, we run the code below to load the RetinaNet model. The Face Mask Detection dataset has three classes, so we set the num_classes parameter to 3. Since we will perform transfer learning, we load pre-trained weights for the backbone only and leave the remaining weights randomly initialized. The ResNet-50 backbone weights come from pre-training on the ImageNet classification dataset; the full detector's COCO weights are not used here, because the detection head must be re-initialized for our 3 classes.
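The model-loading cell is not shown above; a sketch consistent with this description and with the retina variable used in section 4.5 (retinanet_resnet50_fpn is the RetinaNet constructor available in torchvision 0.8):

retina = torchvision.models.detection.retinanet_resnet50_fpn(
    num_classes=3,             # three mask classes; the detection head is freshly initialized
    pretrained=False,          # do not load the full COCO-trained detector
    pretrained_backbone=True,  # load pre-trained ResNet-50 backbone weights
)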
4.5 Transfer Learning
Once the model is loaded, we proceed with transfer learning using the code below.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

num_epochs = 30
retina.to(device)

# parameters: collect only the parameters that require gradient computation
params = [p for p in retina.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005,
                            momentum=0.9, weight_decay=0.0005)

len_dataloader = len(data_loader)

# roughly 4 minutes per epoch
for epoch in range(num_epochs):
    start = time.time()
    retina.train()

    i = 0
    epoch_loss = 0
    for images, targets in data_loader:
        images = list(image.to(device) for image in images)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        # in train mode, torchvision detection models return a dict of losses
        loss_dict = retina(images, targets)
        losses = sum(loss for loss in loss_dict.values())

        i += 1

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        epoch_loss += losses
    print(epoch_loss, f'time: {time.time() - start}')
tensor(285.9670, device='cuda:0', grad_fn=<AddBackward0>) time: 242.22558188438416
tensor(268.1001, device='cuda:0', grad_fn=<AddBackward0>) time: 251.5482075214386
tensor(248.4554, device='cuda:0', grad_fn=<AddBackward0>) time: 248.92862486839294
tensor(233.0612, device='cuda:0', grad_fn=<AddBackward0>) time: 249.69438576698303
tensor(234.2285, device='cuda:0', grad_fn=<AddBackward0>) time: 247.88670659065247
tensor(202.4744, device='cuda:0', grad_fn=<AddBackward0>) time: 249.68517541885376
tensor(172.9739, device='cuda:0', grad_fn=<AddBackward0>) time: 250.47061586380005
tensor(125.8968, device='cuda:0', grad_fn=<AddBackward0>) time: 251.4771168231964
tensor(102.0443, device='cuda:0', grad_fn=<AddBackward0>) time: 251.20848298072815
tensor(88.1749, device='cuda:0', grad_fn=<AddBackward0>) time: 251.144877910614
tensor(78.1594, device='cuda:0', grad_fn=<AddBackward0>) time: 251.8066761493683
tensor(73.6921, device='cuda:0', grad_fn=<AddBackward0>) time: 251.669575214386
tensor(69.6965, device='cuda:0', grad_fn=<AddBackward0>) time: 251.8230264186859
tensor(63.9101, device='cuda:0', grad_fn=<AddBackward0>) time: 252.08272123336792
tensor(56.2955, device='cuda:0', grad_fn=<AddBackward0>) time: 252.18470931053162
tensor(56.2638, device='cuda:0', grad_fn=<AddBackward0>) time: 252.03237462043762
tensor(50.2047, device='cuda:0', grad_fn=<AddBackward0>) time: 252.09569120407104
tensor(45.9254, device='cuda:0', grad_fn=<AddBackward0>) time: 253.205641746521
tensor(44.4599, device='cuda:0', grad_fn=<AddBackward0>) time: 253.05651235580444
tensor(43.9277, device='cuda:0', grad_fn=<AddBackward0>) time: 253.1837260723114
tensor(40.4117, device='cuda:0', grad_fn=<AddBackward0>) time: 253.18618297576904
tensor(39.0882, device='cuda:0', grad_fn=<AddBackward0>) time: 253.36814761161804
tensor(35.3732, device='cuda:0', grad_fn=<AddBackward0>) time: 253.41503262519836
tensor(34.0460, device='cuda:0', grad_fn=<AddBackward0>) time: 252.93738174438477
tensor(35.8844, device='cuda:0', grad_fn=<AddBackward0>) time: 253.25822925567627
tensor(33.1177, device='cuda:0', grad_fn=<AddBackward0>) time: 253.25469851493835
tensor(28.4753, device='cuda:0', grad_fn=<AddBackward0>) time: 253.2648823261261
tensor(30.3831, device='cuda:0', grad_fn=<AddBackward0>) time: 253.4244725704193
tensor(28.0954, device='cuda:0', grad_fn=<AddBackward0>) time: 253.57142424583435
tensor(28.5899, device='cuda:0', grad_fn=<AddBackward0>) time: 253.16517424583435
To reuse the model later, we save the trained weights by running the code below. The torch.save function writes the learned weights to a path you specify.
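A minimal save sketch (the retina_30.pt file name is an assumption):

torch.save(retina.state_dict(), 'retina_30.pt')  # persist only the learned parameters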
To load the trained weights back, use load_state_dict together with torch.load. If you have re-created the retina variable, you must also move the model onto the GPU before GPU computation is possible.
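And the corresponding loading sketch, under the same file-name assumption:

retina.load_state_dict(torch.load('retina_30.pt'))  # restore the saved weights
retina.to(device)                                   # move the model to GPU memory if available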
4.6 Prediction
Once training is complete, let's examine the prediction results. We will load data from test_data_loader, run it through the trained model, and visualize the predictions alongside the ground truth. First, we define a helper function for prediction.
def make_prediction(model, img, threshold):
    model.eval()  # in eval mode the model returns predictions instead of losses
    preds = model(img)
    for id in range(len(preds)):
        idx_list = []

        for idx, score in enumerate(preds[id]['scores']):
            if score > threshold:  # keep only the indices whose score exceeds the threshold
                idx_list.append(idx)

        preds[id]['boxes'] = preds[id]['boxes'][idx_list]
        preds[id]['labels'] = preds[id]['labels'][idx_list]
        preds[id]['scores'] = preds[id]['scores'][idx_list]

    return preds
The make_prediction function encapsulates inference with the trained model. By adjusting the threshold parameter, only bounding boxes whose confidence score is above a certain level are kept; a value of 0.5 or higher is the usual choice. Next, we use a for loop to run prediction on all the data in test_data_loader.
from tqdm import tqdm

labels = []
preds_adj_all = []
annot_all = []

for im, annot in tqdm(test_data_loader, position = 0, leave = True):
    im = list(img.to(device) for img in im)
    #annot = [{k: v.to(device) for k, v in t.items()} for t in annot]

    for t in annot:
        labels += t['labels']  # collect ground-truth labels for the mAP calculation later

    with torch.no_grad():
        preds_adj = make_prediction(retina, im, 0.5)
        preds_adj = [{k: v.to(torch.device('cpu')) for k, v in t.items()} for t in preds_adj]

        preds_adj_all.append(preds_adj)
        annot_all.append(annot)
We use tqdm to monitor progress. All predictions are now stored in the preds_adj_all variable. Next, we visualize the ground-truth and predicted bounding boxes.
nrows = 8
ncols = 2
fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols*4, nrows*4))

batch_i = 0
for im, annot in test_data_loader:
    pos = batch_i * 4 + 1
    for sample_i in range(len(im)):
        # left column: ground-truth boxes
        img, rects = plot_image_from_output(im[sample_i], annot[sample_i])
        axes[(pos)//2, 1-((pos)%2)].imshow(img)
        for rect in rects:
            axes[(pos)//2, 1-((pos)%2)].add_patch(rect)

        # right column: predicted boxes
        img, rects = plot_image_from_output(im[sample_i], preds_adj_all[batch_i][sample_i])
        axes[(pos)//2, 1-((pos+1)%2)].imshow(img)
        for rect in rects:
            axes[(pos)//2, 1-((pos+1)%2)].add_patch(rect)

        pos += 2

    batch_i += 1
    if batch_i == 4:
        break

# remove x/y ticks
for idx, ax in enumerate(axes.flat):
    ax.set_xticks([])
    ax.set_yticks([])

colnames = ['True', 'Pred']
for idx, ax in enumerate(axes[0]):
    ax.set_title(colnames[idx])

plt.tight_layout()
plt.show()

Using the for loop above, we visualized the ground truth and predictions for 4 batches, 8 images in total. The left column shows the actual bounding box labels and positions, and the right column shows the model's predictions. We can see that mask wearers (green) are detected well, while people without masks (red) are occasionally detected as wearing a mask incorrectly (orange). To evaluate overall model performance, we will compute the mean Average Precision (mAP), a standard metric for evaluating object detection models.
The Tutorial-Book-Utils folder cloned during the data download contains a utils_ObjectDetection.py file. We will use the functions in that module to compute mAP. First, we import the utils_ObjectDetection module.
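That import cell is not preserved here; a sketch, assuming the module is imported straight from the cloned folder:

import sys
sys.path.append('Tutorial-Book-Utils')  # make the cloned utilities importable
import utils_ObjectDetection as utils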
We store the statistics needed to compute mAP for each batch in sample_metrics, and then compute mAP with the ap_per_class function.
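The cell that builds sample_metrics is missing above; a sketch, assuming get_batch_statistics in utils_ObjectDetection.py follows the widely used YOLOv3-style utils and returns per-batch true positives, prediction scores, and predicted labels at a given IoU threshold:

sample_metrics = []
for batch_i in range(len(preds_adj_all)):
    # match predictions to annotations at an IoU threshold of 0.5
    sample_metrics += utils.get_batch_statistics(preds_adj_all[batch_i], annot_all[batch_i], iou_threshold=0.5)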
true_positives, pred_scores, pred_labels = [torch.cat(x, 0) for x in list(zip(*sample_metrics))]  # concatenate statistics across all batches
precision, recall, AP, f1, ap_class = utils.ap_per_class(true_positives, pred_scores, pred_labels, torch.tensor(labels))
mAP = torch.mean(AP)

print(f'mAP : {mAP}')
print(f'AP : {AP}')
Interpreting the results: for class 0, objects without a mask, the model achieves an AP of 0.7684; for class 1, objects wearing a mask, an AP of 0.9188; and for class 2, masks worn incorrectly, an AP of 0.06.
So far, we have built a medical mask detection model by performing transfer learning on RetinaNet. In the next chapter, we will try to improve detection performance using Faster R-CNN, a two-stage detector.