torch.max
torch.max(input) -> Tensor
torch.max(input, dim, keepdim=False, *, out=None) -> tuple (max, max_indices)
torch.max 함수는 텐서에서 최대값을 구하는 함수이다.
import torch
import torch.nn as nn
data = torch.randn((5, 5))
print(data)
print(torch.max(data))
####
# tensor([[ 1.1328, -0.0392, -0.7076, 0.5610, 0.8010],
# [-0.0898, -1.4467, -0.7285, -0.1195, -2.1070],
# [ 0.4547, 1.7739, 0.1664, -1.0242, 0.0474],
# [ 0.2739, 0.8654, -0.2284, 0.7990, 0.2825],
# [-0.9779, 1.1102, 1.0023, 0.2493, -2.6588]])
####
# tensor(1.7739)
위처럼 입력된 Tensor의 최대값을 반환하게 된다.
또한 함수의 입력으로 dimension에 대한 정보를 포함한다면 dimension의 정보를 반영한 최대값을 반환하게 된다.
import torch
import torch.nn as nn
data = torch.randn((5, 5))
print(data)
print(torch.max(data, dim=1))
####
# tensor([[-1.1575e+00, -1.8878e-01, 4.2871e-01, -2.0162e+00, -3.4006e-01],
# [ 6.3486e-01, -8.0135e-01, 2.4311e-01, 3.9699e-01, -1.3349e+00],
# [ 2.6045e-01, 1.2943e+00, 6.0914e-01, -2.7016e-03, -1.0328e+00],
# [-1.0054e+00, 3.4493e-01, 2.9346e+00, -3.3168e-01, 2.2873e-01],
# [-9.0424e-02, -9.4272e-01, -4.7905e-01, 5.4701e-01, -2.0194e+00]])
####
# torch.return_types.max(
# values=tensor([0.4287, 0.6349, 1.2943, 2.9346, 0.5470]),
# indices=tensor([2, 0, 1, 2, 3]))
위처럼 dimension의 정보가 포함되었을때는 dimension을 기준으로 최대값을 추출해주고 이에 따른 indices 정보를 추가적으로 반환해준다. 그렇기 때문에 이때는 tuple 형태로 반환된다.
'PYTHON > PYTORCH' 카테고리의 다른 글
[PYTORCH] Pytorch Lightning이란? (0) | 2021.12.24 |
---|---|
[PYTORCH] 2개 이상의 Loss를 사용할때 주의할 점 (1) | 2021.12.23 |
[파이토치] torch.argmax 함수 (0) | 2021.10.06 |
[파이토치] torch.mm 함수 (0) | 2021.10.06 |
[fairseq] 설치 시 오류 'enum'오류 (0) | 2021.09.09 |