
import os
import pickle
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from metrics import compute_mcc


# Root directory of the experiment; every entry inside it is treated as one
# experiment folder by the aggregation loop below.
base_input_folder = 'experiment_output/gtex_transfer_test_experiment'
exp_folders = os.listdir(base_input_folder)

# Hyper-parameter grids searched during model selection. The values are
# compared against fields parsed from the run-directory names.
# NOTE(review): "lr" is presumably the learning rate and "dpf" a TBR-specific
# hyper-parameter -- confirm against the training script.
lr_list = [0.001, 0.0001]
dpf_list = [0.0, 0.1, 0.01, 0.001, 0.0001]
# Values reported under the "K-Sparse Setting" column of the final table.
sp_list = list(range(4))

# Parallel result columns: index i of each list describes the same row.
model_col = []
sp_col = []
mcc_col = []
mse_col = []
# Declared alongside the other columns but never appended to by the
# aggregation loop in this script.
r2_col = []

def _load_metric_results(run_dir):
    """Return element 0 of the pickled ``metric_results`` file inside *run_dir*."""
    # NOTE: unpickling can execute arbitrary code; only run this script on
    # experiment output that you generated yourself.
    with open(os.path.join(run_dir, 'metric_results'), 'rb') as f:
        return pickle.load(f)[0]


def _matching_runs(subdirs, model_tag, target, sp, lr, dpf=None):
    """Return the names in *subdirs* that belong to the requested configuration.

    Run-directory names are '_'-separated and parsed from the end
    (hard-coded to the naming format as of Apr 14):
    ``[-4]`` dpf, ``[-3]`` lr, ``[-2]`` sp, ``[-1]`` target cell.

    ``model_tag`` and ``target`` are matched as substrings of the name.
    ``dpf=None`` skips the dpf comparison (used for baseline runs, whose
    dpf field is never inspected). ``lr=None`` matches nothing, which is the
    desired outcome when no best learning rate could be selected.
    """
    matches = []
    for subdir in subdirs:
        if model_tag not in subdir or target not in subdir:
            continue
        fields = subdir.split('_')
        if float(fields[-3]) != lr or int(fields[-2]) != sp:
            continue
        if dpf is not None and float(fields[-4]) != dpf:
            continue
        matches.append(subdir)
    return matches


def _mean_or_nan(values):
    """Return the mean of *values*, or NaN (without a NumPy warning) if empty."""
    return np.mean(values) if values else np.nan


def _mean_validation_metric(input_folder, runs, key):
    """Average ``metric_results[key]`` over the given run directories."""
    return _mean_or_nan(
        [_load_metric_results(os.path.join(input_folder, run))[key] for run in runs]
    )


def _mean_test_metrics(input_folder, runs):
    """Return ``(mean MCC, mean test MSE)`` over the given run directories."""
    mcc_values = []
    mse_values = []
    for run in runs:
        metric_results = _load_metric_results(os.path.join(input_folder, run))
        mcc_values.append(
            compute_mcc(x=metric_results['test_embeddings'],
                        y=np.asarray(metric_results['test_generative_latents'], dtype=float))
        )
        mse_values.append(metric_results['test_mse'])
    return _mean_or_nan(mcc_values), _mean_or_nan(mse_values)


for ef in exp_folders:
    input_folder = os.path.join(base_input_folder, ef)
    # List the run directories once; the same listing serves every filter below.
    run_dirs = os.listdir(input_folder)
    # Gather the set of transfer cell types. Many run directories share the
    # same target, so de-duplicate: otherwise each target would be processed
    # (and its result rows appended) once per run directory.
    target_cells = sorted(set(d.split('_')[-1] for d in run_dirs))

    for t in target_cells:
        for sp in sp_list:
            # Select the baseline learning rate with the lowest mean validation MSE.
            best_baseline_lr = None
            best_baseline_result = np.inf
            for lr in lr_list:
                runs = _matching_runs(run_dirs, 'baseline', t, sp, lr)
                mean_loss = _mean_validation_metric(input_folder, runs, 'valid_mse')
                # A NaN mean (no matching runs) never compares as smaller.
                if mean_loss < best_baseline_result:
                    best_baseline_lr = lr
                    best_baseline_result = mean_loss

            # Select the TBR (lr, dpf) pair with the lowest mean validation loss.
            best_tbr_lr = None
            best_tbr_dpf = None
            best_tbr_result = np.inf
            for lr in lr_list:
                for dpf in dpf_list:
                    runs = _matching_runs(run_dirs, 'tbr', t, sp, lr, dpf)
                    mean_loss = _mean_validation_metric(input_folder, runs, 'valid_loss')
                    if mean_loss < best_tbr_result:
                        best_tbr_result = mean_loss
                        best_tbr_lr = lr
                        best_tbr_dpf = dpf
            print('parameters selected')
            print('Baseline LR: {}'.format(best_baseline_lr))
            print('TBR LR: {}'.format(best_tbr_lr))
            print('TBR DPF: {}'.format(best_tbr_dpf))

            # Report test metrics of the selected TBR configuration.
            tbr_runs = _matching_runs(run_dirs, 'tbr', t, sp, best_tbr_lr, best_tbr_dpf)
            tbr_mcc, tbr_mse = _mean_test_metrics(input_folder, tbr_runs)
            mcc_col.append(tbr_mcc)
            mse_col.append(tbr_mse)
            model_col.append('TBR')
            sp_col.append(sp)

            # Report test metrics of the selected baseline configuration.
            baseline_runs = _matching_runs(run_dirs, 'baseline', t, sp, best_baseline_lr)
            baseline_mcc, baseline_mse = _mean_test_metrics(input_folder, baseline_runs)
            mcc_col.append(baseline_mcc)
            mse_col.append(baseline_mse)
            model_col.append('Baseline')
            sp_col.append(sp)


# Assemble the result table directly from the populated columns.
# The previous np.vstack([... , r2_col]) call could not work: r2_col is never
# filled, so its length differs from the other columns, and five stacked
# columns were being labelled with only four names.
df = pd.DataFrame({
    'Model': model_col,
    'MCC': mcc_col,
    'MSE': mse_col,
    'K-Sparse Setting': sp_col,
})
df = df.astype({'MCC': 'float', 'MSE': 'float'})
sns.set_style("whitegrid")

barplot_mse = sns.catplot(x="K-Sparse Setting", y="MSE", hue='Model', kind="bar", data=df, palette="mako")
barplot_mse.set(title='MSE - Multi Exp')
# Pass the flag positionally: the keyword was renamed from ``b`` to
# ``visible`` in Matplotlib 3.5, so ``b=None`` fails on current releases.
# ``None`` with no other arguments toggles the grid, as before.
plt.grid(None)
# Make sure the output directory exists before writing the figure.
os.makedirs('plots', exist_ok=True)
barplot_mse.savefig('plots/transfer_mse_bar_{}.pdf'.format(base_input_folder.split('/')[-1]))

