torch.mm
torch.mm(input, mat2, *, out=None) → Tensor
mm은 input과 mat2에 대해서 matrix multiplication을 수행하는 함수이다. 이 함수는 broadcast되지 않는 것이 특징이다. 만약 broadcasting된 것을 원한다면 torch.matul 함수를 사용하여아 한다.
torch.mm(
torch.tensor([[2, 2], [2, 2]]),
torch.tensor([[2, 2], [2, 2]])
)
#### output ####
'''
tensor([[8, 8],
[8, 8]])
'''
간단하게 torch를 import 하고 코드를 돌려보면 mm의 동작을 살펴볼 수 있다.
'PYTHON > PYTORCH' 카테고리의 다른 글
[PYTORCH] torch.max 함수 설명 (0) | 2021.12.22 |
---|---|
[파이토치] torch.argmax 함수 (0) | 2021.10.06 |
[fairseq] 설치 시 오류 'enum'오류 (0) | 2021.09.09 |
[TorchAudio] Transformations 알아보기 (0) | 2021.01.11 |
[파이토치] Melspectrogram 추출하기 (0) | 2021.01.10 |