-
로지스틱 회귀부스트캠프 Ai tech/2주차 2022. 1. 30. 18:26
로지스틱 회귀
- 둘 중 하나를 결정하는 문제를 이진 분류(Binary Classfication)라고 하는데, 이진 분류를 풀기 위한 대표적인 알고리즘으로 로지스틱 회귀(Logistic Regression)이 있다.
- 선형 회귀 때의 H(x) = Wx + b가 아니라, S자 모양의 그래프를 만들 수 있는 특정 함수 f를 추가적으로 사용하여 나타낸다. H(x) = f(Wx+b) 여기서 사용되는 f = 시그모이드 함수
x = np.arange(-5.0, 5.0, 0.1) y1 = sigmoid(x * 0.5) y2 = sigmoid(x) y3 = sigmoid(x * 2) plt.plot(x, y1, 'r', linestyle='--') plt.plot(x, y2, 'g') plt.plot(x, y3, 'b', linestyle='--') plt.plot([0,0],[1.0,0.0], ':') plt.title('Sigmoid Function') plt.show()
sigmoid 함수에서 x의 계수가 커질수록 경사폭이 커지는것을 알 수 있다.
구현
import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim import numpy as np x_data = [[1, 2], [2, 3], [3, 1], [4, 3], [5, 3], [6, 2]] y_data = [[0], [0], [0], [1], [1], [1]] x_train = torch.FloatTensor(x_data) y_train = torch.FloatTensor(y_data) optimizer = optim.SGD(model.parameters(), lr = 1) nb_epochs = 1000 for epoch in range(nb_epochs + 1): hypothesis = model(x_train) cost = F.binary_cross_entropy(hypothesis, y_train) optimizer.zero_grad() cost.backward() optimizer.step() if epoch % 100 == 0: prediction = hypothesis >= torch.tensor([0.5]) # 예측값이 0.5를 넘으면 True로 간주 correct_prediction = prediction.float() == y_train # 실제값과 일치하는 경우만 True로 간주 accuracy = correct_prediction.sum().item() / len(correct_prediction) # 정확도를 계산 print('Epoch {:4d}/{} Cost: {:.6f} Accuracy {:2.2f}%'.format( # 각 에포크마다 정확도를 출력 epoch, nb_epochs, cost.item(), accuracy * 100, ))
결과
Epoch 0/1000 Cost: 0.527002 Accuracy 83.33% Epoch 100/1000 Cost: 0.133369 Accuracy 100.00% Epoch 200/1000 Cost: 0.080169 Accuracy 100.00% Epoch 300/1000 Cost: 0.057659 Accuracy 100.00% Epoch 400/1000 Cost: 0.045154 Accuracy 100.00% Epoch 500/1000 Cost: 0.037163 Accuracy 100.00% Epoch 600/1000 Cost: 0.031602 Accuracy 100.00% Epoch 700/1000 Cost: 0.027503 Accuracy 100.00% Epoch 800/1000 Cost: 0.024353 Accuracy 100.00% Epoch 900/1000 Cost: 0.021855 Accuracy 100.00% Epoch 1000/1000 Cost: 0.019825 Accuracy 100.00%
'부스트캠프 Ai tech > 2주차' 카테고리의 다른 글
10일 (2) 2022.01.28 9일 (0) 2022.01.27 8일 (0) 2022.01.26 7일 - 과제 정리 (5) 2022.01.26 6일 - torch.gather (0) 2022.01.24