소갱
2022. 1. 27. 13:04
torch.nn.CrossEntropyLoss
- Classfication에서 많이 사용된다.
- Softmax 함수를 통해 최종값들을 변환시키고 ([0,1] 총 합은 1이 되도록 하는 것), 그 다음 One-hot Label과의 Cross Entropy를 통해 Loss를 구하게 된다.
# ImageNet에서 학습된 ResNet 18 딥러닝 모델을 불러옴
imagenet_resnet18 = torchvision.models.resnet18(pretrained=True)
print("네트워크 필요 입력 채널 개수", imagenet_resnet18.conv1.weight.shape[1])
print("네트워크 출력 채널 개수 (예측 class type 개수)", imagenet_resnet18.fc.weight.shape[0])
print(imagenet_resnet18)
torchvision.models.resnet18
여기서 imagenet_resnet18을 살펴보게되면
위와 같이 conv1: Conv2d(3, 64...)로 나온 것을 알 수 있다. 이는 입력채널이 3개 필요하다는 것이고(R,G,B) 이를 1개(Grayscale)로 만들기 위해서 필요한 코드는 다음과 같다.
target_model = imagenet_resnet18
target_model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)