Model
This guide explains how to build neural network models in PyTorch.
Author: 권지현
Last modified: 21-09-29
Written as part of the Pytorch guide crew activity in the 3rd cohort of 가짜연구소 (Pseudo Lab).
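1. Model Class
1.1 What Is a Class
1.2 Defining a Class
A model is defined by subclassing nn.Module: the layers are created in __init__, and forward describes how the input flows through them.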
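A class bundles data (attributes) with the functions that operate on it (methods), and a subclass can extend a parent class. As a minimal illustration that does not involve PyTorch, the sketch below uses hypothetical `Animal` and `Dog` classes; the subclass calls the parent initializer with `super().__init__()`, the same pattern that nn.Module subclasses use in 1.2.

```python
class Animal:
    """A simple base class (hypothetical example) with one attribute and one method."""

    def __init__(self, name):
        # Store per-instance data as an attribute.
        self.name = name

    def describe(self):
        return f"{self.name} is an animal"


class Dog(Animal):
    """A subclass that extends Animal with an extra attribute."""

    def __init__(self, name, breed):
        # Run the parent initializer first so self.name is set.
        super().__init__(name)
        self.breed = breed

    def describe(self):
        # Override the parent method while reusing its result.
        return f"{super().describe()} ({self.breed})"


dog = Dog("Rex", "beagle")
print(dog.describe())  # Rex is an animal (beagle)
```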
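import torch
from torch import nn


class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        # Flatten each 28x28 image into a 784-element vector (batch dim kept).
        self.flatten = nn.Flatten()
        # Fully connected layers with ReLU activations in between.
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
            # No ReLU here: the outputs are raw logits and may be negative.
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits


# Use a GPU if one is available, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = NeuralNetwork().to(device)

# Run one random 28x28 input through the model.
X = torch.rand(1, 28, 28, device=device)
logits = model(X)
# Convert logits to class probabilities, then pick the most likely class.
pred_probab = nn.Softmax(dim=1)(logits)
y_pred = pred_probab.argmax(1)
print(f"Predicted class: {y_pred}")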
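The model above works because nn.Module keeps track of every module assigned as an attribute in `__init__`. As a rough sketch, assuming a hypothetical two-layer `TinyNet` that is not part of this guide, the code below lists the registered parameters with `named_parameters()` and shows that calling the module runs its `forward` method.

```python
import torch
from torch import nn


class TinyNet(nn.Module):  # hypothetical example model
    def __init__(self):
        super().__init__()
        # Modules assigned as attributes are registered as submodules.
        self.hidden = nn.Linear(4, 3)
        self.out = nn.Linear(3, 2)

    def forward(self, x):
        return self.out(torch.relu(self.hidden(x)))


net = TinyNet()

# Registered parameters are listed in registration order with dotted names.
for name, param in net.named_parameters():
    print(name, tuple(param.shape))
# hidden.weight (3, 4)
# hidden.bias (3,)
# out.weight (2, 3)
# out.bias (2,)

# Calling the module runs forward(); a batch of 5 inputs gives 5 x 2 outputs.
y = net(torch.rand(5, 4))
print(y.shape)  # torch.Size([5, 2])
```

In practice you call the module (`net(x)`) rather than `net.forward(x)`, because `__call__` also runs any registered hooks.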
2. Layers
2.1 nn.Flatten
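input_image = torch.rand(3, 28, 28)  # an example minibatch of 3 random 28x28 images
flatten = nn.Flatten()  # keeps dim 0 (batch), merges the remaining dims
flat_image = flatten(input_image)
print(flat_image.size())  # torch.Size([3, 784])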
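The batch dimension survives because nn.Flatten defaults to `start_dim=1, end_dim=-1`. The sketch below, using an illustrative random batch, contrasts it with `torch.flatten`, whose default `start_dim=0` also merges the batch dimension, and checks that flattening is the same as a row-major reshape.

```python
import torch
from torch import nn

images = torch.rand(3, 28, 28)  # illustrative batch of 3 random 28x28 "images"

# nn.Flatten (start_dim=1 by default) keeps dim 0 and merges the rest.
flat = nn.Flatten()(images)
print(flat.shape)  # torch.Size([3, 784])

# torch.flatten (start_dim=0 by default) merges the batch dimension too.
everything = torch.flatten(images)
print(everything.shape)  # torch.Size([2352])

# Values keep their row-major order, so this matches a reshape.
assert torch.equal(flat, images.reshape(3, -1))
```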
2.2 nn.Linear
layer1 = nn.Linear(in_features=28*28, out_features=20)  # maps 784 features to 20
hidden1 = layer1(flat_image)
print(hidden1.size())  # torch.Size([batch_size, 20])
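nn.Linear applies an affine transformation to its input. As a small check on random data (the seed and shapes are arbitrary choices for illustration), the sketch below inspects the stored weight and bias shapes and reproduces the layer's output as `x @ W.T + b`.

```python
import torch
from torch import nn

torch.manual_seed(0)
layer = nn.Linear(in_features=784, out_features=20)
x = torch.rand(3, 784)  # arbitrary batch of 3 flattened inputs

# weight has shape (out_features, in_features); bias has shape (out_features,).
print(layer.weight.shape, layer.bias.shape)

# The forward pass is the affine map y = x @ W.T + b.
manual = x @ layer.weight.T + layer.bias
assert torch.allclose(layer(x), manual, atol=1e-5)
```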
2.3 nn.ReLU
print(f"Before ReLU: {hidden1}\n\n")
hidden1 = nn.ReLU()(hidden1)  # negative values become 0, the rest are unchanged
print(f"After ReLU: {hidden1}")
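ReLU computes max(0, x) elementwise. The sketch below applies it to a hand-picked tensor (values chosen only for illustration) and shows two equivalent forms, `torch.relu` and `Tensor.clamp(min=0)`.

```python
import torch
from torch import nn

x = torch.tensor([[-2.0, -0.5, 0.0, 0.5, 2.0]])

# Negatives become 0; zero and positive values pass through unchanged.
out = nn.ReLU()(x)
print(out)

assert torch.equal(out, torch.tensor([[0.0, 0.0, 0.0, 0.5, 2.0]]))
# The same operation in functional and tensor-method form.
assert torch.equal(out, torch.relu(x))
assert torch.equal(out, x.clamp(min=0))
```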
2.4 nn.Sequential
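# Chain the layers defined above; data flows through them in this order.
seq_modules = nn.Sequential(
    flatten,
    layer1,
    nn.ReLU(),
    nn.Linear(20, 10)
)
input_image = torch.rand(3, 28, 28)
logits = seq_modules(input_image)  # shape: (3, 10)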
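nn.Sequential can also take an OrderedDict so that each layer gets a name. In the sketch below the layer names (`flatten`, `hidden`, `relu`, `out`) are hypothetical choices; it shows that layers are reachable by index or by name, and that calling the container equals running its layers one by one in order.

```python
from collections import OrderedDict

import torch
from torch import nn

seq = nn.Sequential(OrderedDict([
    ("flatten", nn.Flatten()),
    ("hidden", nn.Linear(28 * 28, 20)),
    ("relu", nn.ReLU()),
    ("out", nn.Linear(20, 10)),
]))

x = torch.rand(3, 28, 28)
y = seq(x)
print(y.shape)  # torch.Size([3, 10])

# With an OrderedDict, a layer is reachable by position or by name.
print(seq[1] is seq.hidden)  # True

# Feeding each layer's output into the next reproduces seq(x).
manual = x
for layer in seq:
    manual = layer(manual)
assert torch.equal(manual, y)
```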
2.5 nn.Softmax
# dim=1: normalize over the class dimension so each row sums to 1.
softmax = nn.Softmax(dim=1)
pred_probab = softmax(logits)
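With `dim=1`, softmax turns each row of logits into a probability distribution. The sketch below uses random stand-in logits (3 samples, 10 classes, an arbitrary seed) to check that the probabilities are non-negative, sum to 1 per row, and give the same argmax as the raw logits, since softmax preserves the ordering within a row.

```python
import torch
from torch import nn

torch.manual_seed(0)
logits = torch.randn(3, 10)  # stand-in logits: 3 samples, 10 classes

probs = nn.Softmax(dim=1)(logits)
print(probs.sum(dim=1))  # each entry is 1 up to floating-point rounding

assert torch.all(probs >= 0)
assert torch.allclose(probs.sum(dim=1), torch.ones(3))

# Softmax is monotonic within a row, so the predicted class is unchanged.
assert torch.equal(probs.argmax(dim=1), logits.argmax(dim=1))
```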
3. torch.nn API
https://pytorch.org/docs/stable/nn.html