import seaborn as sns
import matplotlib
# Embed fonts as TrueType (Type 42) instead of matplotlib's default Type 3 in
# PDF and PostScript output. NOTE(review): presumably so text in the saved
# figure stays editable / meets publisher font requirements — confirm intent.
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42
import matplotlib
import matplotlib.pyplot as plt
import numpy as np


def tidy_plot(ax):
    """Draw a thin black frame around *ax*.

    Every spine of the axes is forced visible, coloured black and given a
    line width of 1, regardless of how earlier plotting calls styled it.

    Args:
        ax: The axes whose spines are restyled in place.
    """
    # (setter name, value) pairs, applied to each spine in this order.
    frame_style = (
        ("set_visible", True),
        ("set_color", 'black'),
        ("set_linewidth", 1),
    )
    for edge in ax.spines.values():
        for setter, value in frame_style:
            getattr(edge, setter)(value)


def main(data_path='./heatmap_data.txt', out_path='./plots/vtab_heatmap.pdf'):
    """Render the FiLM-vs-Head comparison heatmap and save it as a PDF.

    Reads a comma-separated grid of numbers and draws it as an annotated
    seaborn heatmap with one column per task in ``x_labels`` and one row per
    epsilon value in ``y_labels``. The diverging ``seismic`` colour scale is
    fixed to [-50, 50]: positive cells are red ("FiLM better") and negative
    cells blue ("Head better"), as the colour-bar label states.
    NOTE(review): presumably each cell is a FiLM-minus-Head score — confirm.

    Args:
        data_path: CSV file holding the grid. It must have ``len(y_labels)``
            rows and ``len(x_labels)`` columns.
        out_path: Destination of the figure. Missing parent directories are
            created.

    Raises:
        ValueError: If the grid's shape does not match the axis labels.
    """
    from pathlib import Path

    x_labels = ["retinopathy", "flowers", "caltech", "pets", "dtd", "eurosat", "camelyon", "resics", "cifar100", "sun",
                "kitti", "dmlab", "clevr-c", "snorb-el", "snorb-az", "dsprites-o", "clevr-d", "svhn", "dsprites-l"]
    # Raw string: "\i" is not a valid escape sequence, so a plain literal
    # emits a SyntaxWarning on Python 3.12+ (DeprecationWarning before).
    y_labels = ["1", "2", "4", "8", r"$\infty$"]

    # atleast_2d keeps a single-row file from collapsing to a 1-D array.
    data = np.atleast_2d(np.genfromtxt(data_path, delimiter=','))
    expected_shape = (len(y_labels), len(x_labels))
    if data.shape != expected_shape:
        raise ValueError(
            f"{data_path}: expected a {expected_shape[0]}x{expected_shape[1]} grid "
            f"(epsilon rows x task columns), got shape {data.shape}"
        )

    fig, ax = plt.subplots(1, 1, figsize=(10, 2.5))
    # Passing the labels here makes seaborn place exactly one tick per cell.
    # Its default "auto" mode may thin out ticks, which would misalign (or
    # reject) labels assigned afterwards.
    sns.heatmap(data, ax=ax, cbar=True, annot=True, fmt='.1f', cmap='seismic', vmin=-50, vmax=50,
                xticklabels=x_labels, yticklabels=y_labels,
                cbar_kws={"label": "Red: FiLM better\nBlue: Head better"})

    plt.setp(ax.get_yticklabels(), rotation=0, fontsize='large')
    plt.setp(ax.get_xticklabels(), rotation=30, ha="right", fontsize='medium')
    ax.set_ylabel('$ϵ$', fontsize='large')
    tidy_plot(ax)
    fig.tight_layout()

    # savefig does not create directories; make sure the target folder exists.
    out_file = Path(out_path)
    out_file.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_file, bbox_inches='tight')
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


if __name__ == '__main__':
    main()