import os
import pickle
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from metrics import compute_mcc


# Root folder of the experiment output; every sub-directory in it is treated
# as one experiment folder containing the individual run directories.
base_input_folder = 'experiment_output/gtex_experiment'

# os.listdir returns entries in arbitrary order and also returns plain files
# (e.g. .DS_Store, logs). Keep only directories so the per-experiment listing
# below cannot fail on a stray file, and sort for a reproducible row order.
exp_folders = sorted(
    entry for entry in os.listdir(base_input_folder)
    if os.path.isdir(os.path.join(base_input_folder, entry))
)

# Hyper-parameter grids searched during model selection. The values must match
# the numbers encoded in the run directory names (compared after float()/int()).
dpf_list = [0.0, 0.1, 0.01, 0.001, 0.0001]
lr_list = [0.001, 0.0001]
sp_list = [0, 1, 2]

# Column accumulators for the results table: two entries (TBR, Baseline) are
# appended per (experiment folder, sparsity setting) pair.
model_col = []
mcc_col = []
mse_col = []
sp_col = []

def _load_metric_results(run_dir):
    """Return the first element of the pickled ``metric_results`` file in *run_dir*.

    The file handle is closed before the caller touches the data.
    """
    # NOTE: pickle.load can execute arbitrary code embedded in the file; only
    # run this script on experiment output you produced yourself.
    with open(os.path.join(run_dir, 'metric_results'), 'rb') as f:
        return pickle.load(f)[0]


def _mean_or_nan(values):
    """Return the mean of *values*, or NaN (without a RuntimeWarning) if empty."""
    return np.mean(values) if len(values) > 0 else np.nan


def _baseline_runs(subdirs, lr, sp):
    """Return the baseline run names in *subdirs* matching learning rate *lr* and sparsity *sp*."""
    # Run names are expected to end in ``..._<lr>_<sp>`` (hard-coded to the
    # format as of May 9, may change). A 'baseline' name that does not parse
    # still raises, exactly as before.
    return [
        name for name in subdirs
        if 'baseline' in name
        and float(name.split('_')[-2]) == lr
        and int(name.split('_')[-1]) == sp
    ]


def _tbr_runs(subdirs, lr, dpf, sp):
    """Return the TBR run names in *subdirs* matching *lr*, *dpf* and sparsity *sp*."""
    # Run names are expected to end in ``..._<dpf>_<lr>_<sp>`` (hard-coded to
    # the format as of May 9, may change).
    return [
        name for name in subdirs
        if 'tbr' in name
        and float(name.split('_')[-2]) == lr
        and float(name.split('_')[-3]) == dpf
        and int(name.split('_')[-1]) == sp
    ]


for ef in exp_folders:
    input_folder = os.path.join(base_input_folder, ef)
    # List the run directories once per experiment folder instead of once per
    # grid point; the listing does not change while the script runs.
    subdirs = os.listdir(input_folder)

    for sp in sp_list:
        # selecting the best baseline model: lowest mean validation MSE
        # NOTE(review): the baseline is selected on 'valid_mse' while TBR is
        # selected on 'valid_loss' below -- confirm this asymmetry is intended.
        best_baseline_lr = None
        best_baseline_result = np.inf
        for lr in lr_list:
            loss_list = [
                _load_metric_results(os.path.join(input_folder, name))['valid_mse']
                for name in _baseline_runs(subdirs, lr, sp)
            ]
            if not loss_list:
                # No run for this grid point: skip instead of averaging an
                # empty list (which yields NaN plus a RuntimeWarning).
                continue
            mean_loss = np.mean(loss_list)
            if mean_loss < best_baseline_result:
                best_baseline_lr = lr
                best_baseline_result = mean_loss

        # selecting the best TBR model: lowest mean validation loss
        best_tbr_lr = None
        best_tbr_dpf = None
        best_tbr_result = np.inf
        for lr in lr_list:
            for dpf in dpf_list:
                loss_list = [
                    _load_metric_results(os.path.join(input_folder, name))['valid_loss']
                    for name in _tbr_runs(subdirs, lr, dpf, sp)
                ]
                if not loss_list:
                    continue
                mean_loss = np.mean(loss_list)
                if mean_loss < best_tbr_result:
                    best_tbr_result = mean_loss
                    best_tbr_lr = lr
                    best_tbr_dpf = dpf

        # A value of None below means no matching run was found for this
        # experiment folder / sparsity setting.
        print('parameters selected')
        print('Baseline LR: {}'.format(best_baseline_lr))
        print('TBR LR: {}'.format(best_tbr_lr))
        print('TBR DPF: {}'.format(best_tbr_dpf))

        # Test metrics of the selected TBR configuration, averaged over all
        # runs that share it.
        tbr_mcc = []
        tbr_mse = []
        for name in _tbr_runs(subdirs, best_tbr_lr, best_tbr_dpf, sp):
            metric_results = _load_metric_results(os.path.join(input_folder, name))
            tbr_mcc.append(compute_mcc(
                x=metric_results['test_embeddings'],
                y=np.asarray(metric_results['test_generative_latents'], dtype=float)))
            tbr_mse.append(metric_results['test_mse'])

        # NaN is recorded when nothing matched, so the table keeps one row per
        # (experiment, sparsity, model) as before.
        mcc_col.append(_mean_or_nan(tbr_mcc))
        mse_col.append(_mean_or_nan(tbr_mse))
        model_col.append('TBR')
        sp_col.append(sp)

        # Same aggregation for the selected baseline configuration.
        baseline_mcc = []
        baseline_mse = []
        for name in _baseline_runs(subdirs, best_baseline_lr, sp):
            metric_results = _load_metric_results(os.path.join(input_folder, name))
            baseline_mcc.append(compute_mcc(
                x=metric_results['test_embeddings'],
                y=np.asarray(metric_results['test_generative_latents'], dtype=float)))
            baseline_mse.append(metric_results['test_mse'])

        mcc_col.append(_mean_or_nan(baseline_mcc))
        mse_col.append(_mean_or_nan(baseline_mse))
        model_col.append('Baseline')
        sp_col.append(sp)


# Assemble the long-form results table directly from typed columns instead of
# stacking everything into one string array and casting the metrics back.
# The sparsity setting is kept as a string, as it was after the former vstack,
# so it remains a categorical axis.
df = pd.DataFrame({
    'Model': model_col,
    'MCC': np.asarray(mcc_col, dtype=float),
    'MSE': np.asarray(mse_col, dtype=float),
    'K-Sparse Setting': [str(sp) for sp in sp_col],
})

# savefig does not create missing directories.
os.makedirs('plots', exist_ok=True)
exp_name = base_input_folder.split('/')[-1]

sns.set_style("whitegrid")

# Bar chart of MCC per sparsity setting, aggregated over experiment folders.
barplot_mcc = sns.catplot(x="K-Sparse Setting", y="MCC", hue='Model', kind="bar", data=df, palette="mako")
barplot_mcc.set(title='MCC Scores - Multi Exp')
barplot_mcc.savefig('plots/mcc_bar_{}.png'.format(exp_name))
# plt.grid(None) toggles the grid on the current axes. The argument is passed
# positionally because the keyword was renamed from ``b`` to ``visible`` and
# ``b=`` is rejected by current Matplotlib releases.
# NOTE(review): this toggle runs after savefig, so the saved bar chart keeps
# its grid; move it above savefig if the grid was meant to be removed.
plt.grid(None)
plt.close()

# Per-experiment MCC values as a swarm plot, grid toggled before saving.
boxplot_mcc = sns.catplot(x="K-Sparse Setting", y="MCC", hue='Model', kind="swarm", data=df, palette="mako")
plt.grid(None)
boxplot_mcc.savefig('plots/mcc_scatter_{}.pdf'.format(exp_name))
plt.close()

# plotting the mse
boxplot_mse = sns.catplot(x="K-Sparse Setting", y="MSE", hue='Model', kind="swarm", data=df, palette="mako")
plt.grid(None)
boxplot_mse.savefig('plots/mse_scatter_{}.pdf'.format(exp_name))
plt.close()

