본문 바로가기

PYTHON/PYTORCH

[PYTORCH] torch.max 함수 설명

 

 

torch.max


torch.max(input) -> Tensor

torch.max(input, dim, keepdim=False, *, out=None) -> tuple (max, max_indices)


torch.max 함수는 텐서에서 최대값을 구하는 함수이다.

import torch
import torch.nn as nn

data = torch.randn((5, 5))
print(data)
print(torch.max(data))

####
# tensor([[ 1.1328, -0.0392, -0.7076,  0.5610,  0.8010],
#         [-0.0898, -1.4467, -0.7285, -0.1195, -2.1070],
#         [ 0.4547,  1.7739,  0.1664, -1.0242,  0.0474],
#         [ 0.2739,  0.8654, -0.2284,  0.7990,  0.2825],
#         [-0.9779,  1.1102,  1.0023,  0.2493, -2.6588]])

####
# tensor(1.7739)

위처럼 입력된 Tensor의 최대값을 반환하게 된다.

 

또한 함수의 입력으로 dimension에 대한 정보를 포함한다면 dimension의 정보를 반영한 최대값을 반환하게 된다.

import torch
import torch.nn as nn

data = torch.randn((5, 5))
print(data)
print(torch.max(data, dim=1))

####
# tensor([[-1.1575e+00, -1.8878e-01,  4.2871e-01, -2.0162e+00, -3.4006e-01],
#         [ 6.3486e-01, -8.0135e-01,  2.4311e-01,  3.9699e-01, -1.3349e+00],
#         [ 2.6045e-01,  1.2943e+00,  6.0914e-01, -2.7016e-03, -1.0328e+00],
#         [-1.0054e+00,  3.4493e-01,  2.9346e+00, -3.3168e-01,  2.2873e-01],
#         [-9.0424e-02, -9.4272e-01, -4.7905e-01,  5.4701e-01, -2.0194e+00]])
 
####
# torch.return_types.max(
# values=tensor([0.4287, 0.6349, 1.2943, 2.9346, 0.5470]),
# indices=tensor([2, 0, 1, 2, 3]))

위처럼 dimension의 정보가 포함되었을때는 dimension을 기준으로 최대값을 추출해주고 이에 따른 indices 정보를 추가적으로 반환해준다. 그렇기 때문에 이때는 tuple 형태로 반환된다.