본문 바로가기

PYTHON/PYTORCH

[파이토치] torch.argmax 함수

torch.argmax


torch.argmax(input) → LongTensor

torch.argmax(input, dim, keepdim=False) → LongTensor


이 함수는 input tensor에 있는 모든 element들 중에서 가장 큰 값을 가지는 공간의 인덱스 번호를 반환하는 함수이다. 

 

a = torch.randn(4, 4)
print(a)
output = torch.argmax(a)
print(output)

#### output #####
'''
tensor([[-0.5014, -0.1785,  0.2534,  0.7167],
        [-0.7887,  1.0920,  0.5385, -1.1797],
        [-1.0129,  0.2337,  0.5757,  0.9139],
        [ 1.4672, -1.0605, -0.1456, -0.9329]])
tensor(12)
'''

위 코드에서 볼수 있듯이 1->4 행으로 이동(왼쪽에서 오른쪽으로, 위에서 아래로)하면서 가장 큰 값을 가진 value의 위치를 반환한다.

 

만약 파라미터로 dimension을 지정해준다면, 

a = torch.randn(4, 4)
print(a)
output = torch.argmax(a, dim=0)
print(output)

#### output ####
'''
tensor([[-1.6370,  1.4183, -0.1544, -0.7080],
        [ 1.6758, -0.1570,  0.5589, -1.5919],
        [-0.3721,  1.6971,  0.1501, -0.0780],
        [ 0.6539,  0.1301,  0.6457,  0.8172]])
tensor([1, 2, 3, 3])
'''

dimension (열)을 기준으로 안에 있는 최대값의 위치를 각각 반환하게 횐다. 당연하게도 1을 입력하면 열을 기준으로 최대값의 위치를 각각 반환하게 된다.