PYTHON/PYTORCH
[파이토치] torch.argmax 함수
Hitree
2021. 10. 6. 22:46
torch.argmax
torch.argmax(input) → LongTensor
torch.argmax(input, dim, keepdim=False) → LongTensor
이 함수는 input tensor에 있는 모든 element들 중에서 가장 큰 값을 가지는 공간의 인덱스 번호를 반환하는 함수이다.
a = torch.randn(4, 4)
print(a)
output = torch.argmax(a)
print(output)
#### output #####
'''
tensor([[-0.5014, -0.1785, 0.2534, 0.7167],
[-0.7887, 1.0920, 0.5385, -1.1797],
[-1.0129, 0.2337, 0.5757, 0.9139],
[ 1.4672, -1.0605, -0.1456, -0.9329]])
tensor(12)
'''
위 코드에서 볼수 있듯이 1->4 행으로 이동(왼쪽에서 오른쪽으로, 위에서 아래로)하면서 가장 큰 값을 가진 value의 위치를 반환한다.
만약 파라미터로 dimension을 지정해준다면,
a = torch.randn(4, 4)
print(a)
output = torch.argmax(a, dim=0)
print(output)
#### output ####
'''
tensor([[-1.6370, 1.4183, -0.1544, -0.7080],
[ 1.6758, -0.1570, 0.5589, -1.5919],
[-0.3721, 1.6971, 0.1501, -0.0780],
[ 0.6539, 0.1301, 0.6457, 0.8172]])
tensor([1, 2, 3, 3])
'''
dimension (열)을 기준으로 안에 있는 최대값의 위치를 각각 반환하게 횐다. 당연하게도 1을 입력하면 열을 기준으로 최대값의 위치를 각각 반환하게 된다.