PYTHON/PYTORCH
[파이토치] torch.mm 함수
Hitree
2021. 10. 6. 21:55
torch.mm
torch.mm(input, mat2, *, out=None) → Tensor
mm은 input과 mat2에 대해서 matrix multiplication을 수행하는 함수이다. 이 함수는 broadcast되지 않는 것이 특징이다. 만약 broadcasting된 것을 원한다면 torch.matul 함수를 사용하여아 한다.
torch.mm(
torch.tensor([[2, 2], [2, 2]]),
torch.tensor([[2, 2], [2, 2]])
)
#### output ####
'''
tensor([[8, 8],
[8, 8]])
'''
간단하게 torch를 import 하고 코드를 돌려보면 mm의 동작을 살펴볼 수 있다.