본문 바로가기

PYTHON/MATPLOTLIB

[Matplotlib] HeatMap 그리기

Introduction


Python에서는 정말 많은 라이브러리들이 있다. 그 중에서도 머신러닝/딥러닝을 하는 사람들이라면 Heatmap을 한번쯤 써보지 않나 싶다. Heatmap은 '다양한 정보를 일정한 이미지 위에서 열 분포 형태로 visualization 한 것'을 의미한다. 

 

 

Code


import matplotlib.pyplot as plt
import numpy as np
import torch

data = torch.rand(128, 256) # freq, time
# data = data.permute(1, 0)
print(data.size())

data = data.numpy()
plt.figure(figsize=(10, 9))
plt.matshow(data)
plt.colorbar()

구현 코드는 단순하다. 필자는 torch를 정말 많이 사용하기 때문에 torch를 이용하여 텐서를 하나 만들고, numpy로 변환하여 사용하였다. matplotlib에서 heatmap을 그리기 위해서는 matshow() 함수를 사용한다. 파라미터로 numpy행렬만 넣어주면 끝이다. colorbar는 오른쪽에 보이는 긴 막대기를 의미하고 값에 대한 색의 밝기 진하기를 나타낸다. 위 코드는 random 값으로 행렬을 채웠기 때문에 굉장히 난잡하게 나오지만, 아마 실제 값을 찍어보면 멋있게 나올것이다 (개인적으로 필자는 heatmap이 멋있어 보인다).

필자의 heatmap 사용 예는 모델에서 생성한 latent vector를 visualization 할때 사용하였다.