import os
import pickle
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from metrics import compute_mcc

# Output directory for the saved figures.
os.makedirs('figures', exist_ok=True)

# Directory holding one sub-directory per experiment run.
base_dir = 'experiment_output/simulated_exp_tanh_1_treedepth_7'
sparsity_settings = [0, 1, 2, 3, 4, 5]

# Nested container: model name -> quantity name -> sparsity setting -> list.
# Every (model, quantity, setting) combination starts with its own empty list.
model_dict = {
    model_name: {
        quantity: {setting: [] for setting in sparsity_settings}
        for quantity in ('true_latents', 'embeddings', 'mcc_scores', 'MSE')
    }
    for model_name in ('tbr', 'baseline')
}


# Collect the stored results of every run found in base_dir.
# An entry is used when its name starts with a known model name ('<model>_...')
# and its last '_'-separated token parses to one of the sparsity settings.
# Each such entry must be a directory containing a pickled 'metric_results' file.

# Map the numeric value of each setting back to the setting itself so the
# entry name only has to be parsed once (instead of once per setting).
settings_by_value = {float(s): s for s in sparsity_settings}

# sorted(): os.listdir returns entries in arbitrary order; sorting keeps the
# order of repetitions (and therefore of the exported rows) reproducible.
for entry in sorted(os.listdir(base_dir)):
    name_parts = entry.split('_')
    mod = name_parts[0]
    if mod not in model_dict:
        continue

    try:
        s = settings_by_value.get(float(name_parts[-1]))
    except ValueError:
        # Name does not end in a number, so it is not a run directory.
        print('skipping {}: no numeric sparsity suffix'.format(entry))
        continue
    if s is None:
        # Numeric suffix, but not one of the settings being analysed.
        continue

    # NOTE(security): pickle.load can execute arbitrary code; only point
    # base_dir at experiment output you produced yourself.
    with open(os.path.join(base_dir, entry, 'metric_results'), 'rb') as f:
        metric_results = pickle.load(f)[0]

    model_dict[mod]['true_latents'][s].append(metric_results['test_generative_latents'])
    model_dict[mod]['embeddings'][s].append(metric_results['test_embeddings'])
    model_dict[mod]['MSE'][s].append(metric_results['test_mse'])


# For every model and sparsity setting, score each run with compute_mcc,
# passing the run's learned embeddings as x and its true generative latents
# as y. The scores are appended to 'mcc_scores' in run order.
# NOTE(review): compute_mcc comes from the project-local `metrics` module;
# presumably it returns one scalar per run (it is later cast to float).
for m in model_dict:
    for c in sparsity_settings:
        print(c)  # progress indicator: sparsity setting being processed
        # The embeddings and latents lists are filled together, one entry per
        # run, so pairing them covers every repetition. A setting without any
        # runs simply contributes no scores.
        runs = zip(model_dict[m]['embeddings'][c], model_dict[m]['true_latents'][c])
        model_dict[m]['mcc_scores'][c].extend(
            compute_mcc(x=embeddings, y=true_latents)
            for embeddings, true_latents in runs
        )

# Flatten the per-run results into a long-format table with one row per
# (model, sparsity setting, run), and export it as CSV.

# Display names used for the 'Model' column (and thus the plot legends).
mod_title = {
    'tbr': 'TBR',
    'baseline': 'Baseline'
}

model_col = []
sparsity_col = []
mcc_col = []
mse_col = []

for m in model_dict:
    for c in sparsity_settings:
        mcc_list = model_dict[m]['mcc_scores'][c]
        n_runs = len(mcc_list)
        model_col.extend([mod_title[m]] * n_runs)
        sparsity_col.extend([c] * n_runs)
        mcc_col.extend(mcc_list)
        mse_col.extend(model_dict[m]['MSE'][c])

# Build the frame column by column so numeric values stay numeric; stacking
# the mixed-type columns into one array would coerce everything to strings.
df = pd.DataFrame({
    'Model': model_col,
    'K-Sparse Setting': sparsity_col,
    'MCC': mcc_col,
    'MSE': mse_col,
})
df = df.astype({'K-Sparse Setting': 'int', 'MCC': 'float', 'MSE': 'float'})
df.to_csv('mcc_sparse.csv', index=False)

# Draw one box plot per metric (MCC and MSE), grouped by sparsity setting and
# coloured by model, and save each as 'figures/<metric>_<experiment name>.pdf'.
# NOTE(review): an earlier TODO mentioned a "double plotting" issue. Each plot
# is now drawn on its own freshly created axes and its figure is closed
# explicitly after saving, so no figure state carries over — confirm that
# this resolves it.
fig_side = 8  # figures are square: fig_side x fig_side inches
experiment_name = base_dir.split('/')[-1]

for metric in ('MCC', 'MSE'):
    fig, ax = plt.subplots(1, 1, figsize=(fig_side, fig_side))
    sns.boxplot(x="K-Sparse Setting", y=metric, hue='Model', data=df, ax=ax, palette="mako")
    fig.savefig('figures/{}_{}.pdf'.format(metric.lower(), experiment_name))
    # Close this specific figure rather than relying on the "current" one.
    plt.close(fig)


