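"""Plot histograms of log2 perturbation ratios (KL and PEF norm).

For each of the QQP, SNLI and ImageNet entries in ``DATA``, overlay a
histogram of the log2 KL ratios with a histogram of the log2 PEF-norm
ratios, and save the figure as a PDF under ``SAVE_DIR``.

Usage:

    cd ~/Desktop/projects/extract_merge1
    export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1
    python3 ./em/projects/neurips2023/plots/make_perturbation_histograms.py
"""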

import os

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from em.projects.neurips2023.data.main_perturbations import DATA

# Directory that receives the generated PDF figures. The leading "~" is
# expanded to the current user's home directory once, at import time.
SAVE_DIR = os.path.expanduser(
    "~/Desktop/projects_data/extract_merge1/neurips2023/images"
)


def plot(values, title: str, filename: str, *, n_bins: int = 20, save_dir=None):
    """Save overlaid histograms of log2 KL and PEF-norm ratios as a PDF.

    Args:
        values: Mapping with keys ``'kls'`` and ``'norms'``, each a 1-D
            array-like of ratios. Entries should be strictly positive:
            ``np.log2`` maps non-positive values to ``-inf``/``nan``, which
            numpy's automatic histogram binning rejects with ``ValueError``.
        title: Figure title.
        filename: Output file name without extension; ``.pdf`` is appended.
        n_bins: Number of histogram bins, shared by both series.
        save_dir: Output directory; defaults to the module-level
            ``SAVE_DIR``. Created if it does not already exist.
    """
    log_kls = np.log2(values['kls'])
    log_norms = np.log2(values['norms'])

    # Use one set of bin edges spanning both series. Binning each series
    # independently gives them different edges and widths, so the overlaid
    # bars would not line up and their counts would not be comparable.
    bins = np.histogram_bin_edges(np.concatenate((log_kls, log_norms)), bins=n_bins)

    if save_dir is None:
        save_dir = SAVE_DIR
    os.makedirs(save_dir, exist_ok=True)
    filepath = os.path.join(save_dir, f'{filename}.pdf')

    # Size this figure directly (two-thirds of 7 in wide, 2.5 in tall)
    # instead of mutating the global plt.rcParams for every later figure.
    fig, ax = plt.subplots(figsize=(2 * 7 / 3, 2.5))
    try:
        ax.set_title(title)
        ax.hist(log_kls, bins, alpha=0.65, label='KL')
        ax.hist(log_norms, bins, alpha=0.65, label='PEF Norm')
        ax.set_xlabel('Log_2 Ratio')
        ax.set_ylabel('Count')
        ax.legend()
        # Lay out after every artist (title, labels, legend) has been added.
        fig.tight_layout()
        fig.savefig(filepath)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)


plot(DATA['qqp'], 'QQP', 'main_qqp_perturb_ratios')
plot(DATA['snli'], 'NLI', 'main_snli_perturb_ratios')
plot(DATA['imagenet'], 'Vision', 'main_imagenet_perturb_ratios')
