import matplotlib.pyplot as plt
import os

def plot_heatmap(matrix, epoch, title):
    """Render a 2-D matrix as a 'hot' heatmap with a colorbar and save it as PNG.

    The plot title is ``title + f'_epoch{epoch}'`` and the image is written to
    ``title + f'epoch{epoch}.png'``, relative to the current working directory.

    Args:
        matrix: A tensor-like object exposing ``detach()``, ``cpu()`` and
            ``numpy()`` (e.g. a torch tensor), or any array-like that
            ``imshow`` accepts directly (NumPy array, nested lists).
        epoch: Epoch number embedded in both the title and the filename.
        title: Base string used for the plot title and the output filename.

    Returns:
        The path of the written PNG file.
    """
    # Tensors must be detached from the autograd graph and copied to host
    # memory before matplotlib can read them; other array-likes pass through.
    if hasattr(matrix, 'detach'):
        matrix = matrix.detach().cpu().numpy()

    # NOTE(review): the filename omits the underscore used in the plot title
    # ("<title>_epochN" vs "<title>epochN.png"). Kept as-is so existing output
    # paths stay stable.
    filename = title + f'epoch{epoch}.png'

    # Draw on a dedicated figure rather than pyplot's implicit "current" one,
    # so figures the caller already has open are neither painted over nor closed.
    fig, ax = plt.subplots()
    try:
        image = ax.imshow(matrix, cmap='hot')
        fig.colorbar(image, ax=ax)
        ax.set_title(title + f'_epoch{epoch}')
        fig.savefig(filename)
    finally:
        # Release the figure even if drawing or saving raised.
        plt.close(fig)
    return filename
    
