torch.argmax
torch.argmax(input) → LongTensor
torch.argmax(input, dim, keepdim=False) → LongTensor
이 함수는 input tensor에 있는 모든 element들 중에서 가장 큰 값을 가지는 공간의 인덱스 번호를 반환하는 함수이다.
a = torch.randn(4, 4)
print(a)
output = torch.argmax(a)
print(output)
#### output #####
'''
tensor([[-0.5014, -0.1785, 0.2534, 0.7167],
[-0.7887, 1.0920, 0.5385, -1.1797],
[-1.0129, 0.2337, 0.5757, 0.9139],
[ 1.4672, -1.0605, -0.1456, -0.9329]])
tensor(12)
'''
위 코드에서 볼수 있듯이 1->4 행으로 이동(왼쪽에서 오른쪽으로, 위에서 아래로)하면서 가장 큰 값을 가진 value의 위치를 반환한다.
만약 파라미터로 dimension을 지정해준다면,
a = torch.randn(4, 4)
print(a)
output = torch.argmax(a, dim=0)
print(output)
#### output ####
'''
tensor([[-1.6370, 1.4183, -0.1544, -0.7080],
[ 1.6758, -0.1570, 0.5589, -1.5919],
[-0.3721, 1.6971, 0.1501, -0.0780],
[ 0.6539, 0.1301, 0.6457, 0.8172]])
tensor([1, 2, 3, 3])
'''
dimension (열)을 기준으로 안에 있는 최대값의 위치를 각각 반환하게 횐다. 당연하게도 1을 입력하면 열을 기준으로 최대값의 위치를 각각 반환하게 된다.
'PYTHON > PYTORCH' 카테고리의 다른 글
[PYTORCH] 2개 이상의 Loss를 사용할때 주의할 점 (1) | 2021.12.23 |
---|---|
[PYTORCH] torch.max 함수 설명 (0) | 2021.12.22 |
[파이토치] torch.mm 함수 (0) | 2021.10.06 |
[fairseq] 설치 시 오류 'enum'오류 (0) | 2021.09.09 |
[TorchAudio] Transformations 알아보기 (0) | 2021.01.11 |