import torch
import matplotlib.pyplot as plt
import numpy as np
import os
import argparse
import pandas as pd 

from gwdr.data.data import load_dataset
from gwdr.src.scores import scores_clustdr


def retrieve_expnames(dir):
    """Return the names of the experiment sub-folders found directly in *dir*.

    An entry is kept when it is a directory and its name contains the
    substring 'DR' or 'COOT'. Plain files and other directories are ignored.
    Names are returned in the order produced by ``os.listdir``.
    """
    markers = ('DR', 'COOT')
    return [
        entry_name
        for entry_name in os.listdir(dir)
        if os.path.isdir(os.path.join(dir, entry_name))
        and any(marker in entry_name for marker in markers)
    ]


def compute_save_scores_batch_exp(dir, threshold=5, weighted=True, device='cpu', hyperbolic=False):
    """Compute and save silhouette / homogeneity scores for every experiment in *dir*.

    For each experiment folder returned by ``retrieve_expnames(dir)``, the
    function loads ``embeddings.pt`` and ``plans.pt``, loads the labels of the
    experiment's dataset, and calls ``scores_clustdr`` once per
    (key, seed, label level). Results are stored as ``{key: [score per seed]}``
    dictionaries in ``sil_scores.pt`` / ``hom_scores.pt`` (first label level)
    and ``sil_scores2.pt`` / ``hom_scores2.pt`` (second label level, only when
    the label array is 2-dimensional).

    The score files are written into ``dir/<experiment name>``, i.e. the
    folder the embeddings were just loaded from and the folder
    ``retrieve_saved_scores`` reads from. The ``log_dir`` recorded at training
    time is deliberately not used: it may no longer exist when scores are
    computed a posteriori (e.g. after the run folder has been moved).

    Args:
        dir: folder containing one sub-folder per experiment.
        threshold, weighted, hyperbolic: forwarded unchanged to
            ``scores_clustdr``.
        device: device on which labels, plans and embeddings are placed
            before scoring.

    Raises:
        ValueError: if the label array has more than 2 dimensions.
    """
    folder_names = retrieve_expnames(dir)

    print('--- computing and saving scores ---' )
    for name in folder_names:
        exp_dir = os.path.join(dir, name)

        exp_dict, embed_list = torch.load(os.path.join(exp_dir, 'embeddings.pt'))
        exp_dict_, plans_list = torch.load(os.path.join(exp_dir, 'plans.pt'))
        assert exp_dict == exp_dict_
        assert embed_list.keys() == plans_list.keys()

        # The training device (exp_dict['device']) may not be available when
        # scores are computed a posteriori, so the caller-provided one is used.
        Y = load_dataset(exp_dict['dataset'], device=device, Yonly=True)

        # One label vector per label level; every level is scored independently.
        if len(Y.shape) == 1:
            label_levels = [Y]
        elif len(Y.shape) == 2:
            label_levels = [Y[0], Y[1]]
        else:
            raise ValueError('Y has too much labels (>2), code needs to be adapted')

        sil_per_level = [{} for _ in label_levels]
        hom_per_level = [{} for _ in label_levels]

        for key in embed_list.keys():
            for sil_dict, hom_dict in zip(sil_per_level, hom_per_level):
                sil_dict[key], hom_dict[key] = [], []

            for seed in range(exp_dict["n_seeds"]):
                T, Z = plans_list[key][seed], embed_list[key][seed]

                # A float plan is the nan placeholder that can happen with
                # mirror descent not well calibrated: record nan scores.
                failed_run = isinstance(T, float)
                if not failed_run:
                    # Move to the scoring device once, shared by all label levels.
                    T_, Z_ = T.to(device), Z.to(device)

                for labels, sil_dict, hom_dict in zip(label_levels, sil_per_level, hom_per_level):
                    if failed_run:
                        sil_score, hom_score = float('nan'), float('nan')
                    else:
                        sil_score, hom_score = scores_clustdr(
                            T_, Z_, labels,
                            threshold=threshold, weighted=weighted, hyperbolic=hyperbolic)
                        sil_score, hom_score = sil_score.item(), hom_score.item()
                    sil_dict[key].append(sil_score)
                    hom_dict[key].append(hom_score)

        # File naming: no suffix for the first label level, '2' for the second.
        for level, (sil_dict, hom_dict) in enumerate(zip(sil_per_level, hom_per_level)):
            suffix = '' if level == 0 else str(level + 1)
            torch.save([exp_dict, sil_dict], os.path.join(exp_dir, f'sil_scores{suffix}.pt'))
            torch.save([exp_dict, hom_dict], os.path.join(exp_dir, f'hom_scores{suffix}.pt'))
        

def mean_std_scores(score_dict):
    """Summarise a ``{key: sequence of scores}`` mapping.

    Returns three aligned numpy arrays, in the dict's iteration order: the
    keys, the mean of each key's scores, and their (population) standard
    deviation as computed by ``ndarray.std``.
    """
    keys = list(score_dict)
    per_key = [np.array(score_dict[key]) for key in keys]
    mean_scores = [values.mean() for values in per_key]
    std_scores = [values.std() for values in per_key]
    return np.array(keys), np.array(mean_scores), np.array(std_scores)


def plot_score_dict(ax, score_dict, label='default', coeff_std=0.5):
    """Draw the mean score curve of *score_dict* on *ax* with a shaded band.

    The x values are the dict keys, the line is the per-key mean, and the
    half-transparent band spans ``mean +/- coeff_std * std``.
    """
    xs, means, stds = mean_std_scores(score_dict)
    ax.plot(xs, means, label=label)
    lower = means - coeff_std * stds
    upper = means + coeff_std * stds
    ax.fill_between(xs, lower, upper, alpha=0.5)


def retrieve_saved_scores(dir, weighted=True):
    """Load the saved silhouette and homogeneity scores of every experiment in *dir*.

    Returns ``(sil_scores, hom_scores, n_labels)``. Both dicts map an
    experiment name to a list with one score dictionary per label level: the
    content of ``*_scores.pt`` first, followed by ``*_scores2.pt`` when that
    file exists. ``n_labels`` is 2 as soon as one experiment has a
    ``sil_scores2.pt`` file, otherwise 1. ``weighted`` is accepted but not
    used by this function.
    """
    def _load_scores(exp_name, filename):
        # Each file stores [exp_dict, scores]; only the scores are needed here.
        return torch.load(f'{dir}/{exp_name}/{filename}')[1]

    sil_scores, hom_scores = {}, {}
    n_labels = 1

    for name in retrieve_expnames(dir):
        sil_levels = [_load_scores(name, 'sil_scores.pt')]
        if os.path.isfile(f'{dir}/{name}/sil_scores2.pt'):  # if multiple labels
            sil_levels.append(_load_scores(name, 'sil_scores2.pt'))
            n_labels = 2
        sil_scores[name] = sil_levels

        hom_levels = [_load_scores(name, 'hom_scores.pt')]
        if os.path.isfile(f'{dir}/{name}/hom_scores2.pt'):
            hom_levels.append(_load_scores(name, 'hom_scores2.pt'))
        hom_scores[name] = hom_levels

    return sil_scores, hom_scores, n_labels


def plot_losses_batch_exp(dir):
    """Save a training-loss figure for every experiment folder in *dir*.

    For each experiment, ``losses.pt`` is loaded and the first entry of the
    loss list of each key is plotted in its own panel of a 2 x 5 grid (keys
    beyond the tenth are not plotted). The figure is written to
    ``<experiment folder>/training_loss.png``.

    Each figure is closed once saved (even if plotting fails) so that figures
    do not accumulate in memory when *dir* contains many experiments.
    """
    print('--- plotting training losses ---')
    n_rows, n_cols = 2, 5
    folder_names = retrieve_expnames(dir)
    for name in folder_names:
        subdir = os.path.join(dir, name)
        losses = torch.load(subdir+'/losses.pt')[1]
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(20, 8))
        try:
            for i, key in enumerate(losses.keys()):
                if i >= n_rows * n_cols:
                    # The grid is full; remaining keys are not displayed.
                    break
                ax = axes[i // n_cols, i % n_cols]
                ax.plot(losses[key][0])
                ax.set_title('output_sam = '+str(key))
            fig.savefig(subdir+'/training_loss.png', bbox_inches='tight')
        finally:
            plt.close(fig)


def plot_scores_batch_exp(dir, weighted=True, hyperbolic=False):
    """Plot the saved silhouette and homogeneity curves of all experiments in *dir*.

    The figure has one row per label level and one column per score
    (silhouette, homogeneity). Each experiment contributes one curve per
    panel, drawn with ``plot_score_dict``. The legend is shown on the
    first-column panel of the last row. The figure is saved as
    ``<dir>/final_plot.pdf``.

    ``hyperbolic`` only changes the silhouette axis label.

    Raises:
        ValueError: if the number of label levels is neither 1 nor 2.
    """
    print('--- plotting scores ---')
    sil_scores, hom_scores, n_labels = retrieve_saved_scores(dir, weighted)
    assert sil_scores.keys() == hom_scores.keys()

    sil_title = 'silhouette (hyperbolic)' if hyperbolic else 'silhouette'
    panels = [(sil_title, sil_scores), ('homogeneity', hom_scores)]
    n_scores = len(panels)  # silhouette and homogeneity

    # squeeze=False keeps `axes` two-dimensional even with a single row, so
    # the same indexing serves both the 1-label and the 2-label layouts.
    _, axes = plt.subplots(n_labels, n_scores,
                           figsize=(5 * n_scores, 3 * n_labels), squeeze=False)
    if n_labels not in (1, 2):
        raise ValueError('n_labels not implemented')

    legend_row = n_labels - 1
    for exp_name in sil_scores.keys():
        for col, (score_name, scores_by_exp) in enumerate(panels):
            for level in range(n_labels):
                panel = axes[level, col]
                plot_score_dict(panel, scores_by_exp[exp_name][level], label=exp_name)
                if level == legend_row and col == 0:
                    panel.legend(loc='lower left')
                panel.set_xlabel('output samples')
                if n_labels == 1:
                    panel.set_ylabel(score_name)
                else:
                    panel.set_ylabel(f'label {level + 1}, {score_name}')
    plt.savefig(dir+'/final_plot.pdf', bbox_inches='tight')


def save_best_scores_batch_exp(dir, weighted=True):
    """Select, per experiment, the best number of prototypes and export it to CSV.

    Validation is based on an aggregated metric: the mean of the silhouette
    score rescaled from [-1, 1] to [0, 1] and of the homogeneity score. For
    every experiment and label level, the metric is averaged over seeds for
    each number of prototypes and the configuration with the highest average
    is kept.

    Configurations whose average metric is nan (at least one seed produced
    nan scores) are ignored during selection. If every configuration of an
    experiment is nan, the first one is reported so the experiment still
    appears in the output.

    One CSV file is written per label level: ``final_best_scores.csv`` for
    the first level and ``final_best_scores_label<k>.csv`` for level ``k``.
    """
    print('--- identifying best scores ---')
    # perform validation based on silhouette score + homogeneity score
    sil_scores, hom_scores, n_labels = retrieve_saved_scores(dir, weighted)
    assert sil_scores.keys() == hom_scores.keys()

    for n_label in range(n_labels):
        best_res_dict = {
            'exp_name':[],
            'best_metric_mean':[], 'best_metric_std':[], # aggregating silhouette and homogeneity
            'scaled_silhouette_mean':[], 'scaled_silhouette_std':[], # scaled silhouette corresponding to best metric
            'silhouette_mean':[], 'silhouette_std':[], # silhouette corresponding to best metric
            'homogeneity_mean':[], 'homogeneity_std':[], # homogeneity corresponding to best metric
            'n_prototypes':[] # corresponding number of prototypes
        }

        for exp_name in sil_scores.keys():
            # One summary row per candidate number of prototypes.
            candidates = []
            for n_prot in sil_scores[exp_name][0].keys():
                local_sil_score = np.array(sil_scores[exp_name][n_label][n_prot]) # silhouette score per seed
                local_scaled_sil_score = (local_sil_score + 1.) / 2. # from [-1, 1] to [0, 1] to compare to homogeneity
                local_hom_score = np.array(hom_scores[exp_name][n_label][n_prot]) # homogeneity score per seed
                local_metric = (local_scaled_sil_score + local_hom_score) / 2.
                candidates.append({
                    'best_metric_mean': local_metric.mean(),
                    'best_metric_std': local_metric.std(),
                    'scaled_silhouette_mean': local_scaled_sil_score.mean(),
                    'scaled_silhouette_std': local_scaled_sil_score.std(),
                    'silhouette_mean': local_sil_score.mean(),
                    'silhouette_std': local_sil_score.std(),
                    'homogeneity_mean': local_hom_score.mean(),
                    'homogeneity_std': local_hom_score.std(),
                    'n_prototypes': n_prot,
                })

            metric_means = np.array([row['best_metric_mean'] for row in candidates], dtype=float)
            if metric_means.size and np.isnan(metric_means).all():
                # Nothing valid to compare: fall back to the first configuration.
                best_idx = 0
            else:
                # np.argmax would return the first nan entry; nanargmax skips them.
                best_idx = int(np.nanargmax(metric_means))

            best_res_dict['exp_name'].append(exp_name)
            for column, value in candidates[best_idx].items():
                best_res_dict[column].append(value)

        best_res_df = pd.DataFrame(best_res_dict)
        if n_label == 0:
            best_res_df.to_csv(dir+'/final_best_scores.csv', index=False)
        else:
            best_res_df.to_csv(dir+f'/final_best_scores_label{n_label+1}.csv', index=False)
    
    
if __name__ == "__main__":
    def _str2bool(value):
        """Parse a command-line boolean such as 'True', 'false', '1' or 'no'.

        ``type=bool`` must not be used with argparse: ``bool('False')`` is
        True, so any non-empty value would enable the flag.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        # The empty string is kept as False, matching the former bool('') result.
        if lowered in ('false', 'f', 'no', 'n', '0', ''):
            return False
        raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')

    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='')
    parser.add_argument('--threshold', type=float, default=0)
    parser.add_argument('--weighted', type=_str2bool, default=True)
    parser.add_argument('--hyperbolic', type=_str2bool, default=False)

    args = parser.parse_args()
    # Experiments are looked up in the 'runs' folder of the current working directory.
    path_config_folder = os.getcwd()+'/runs'
    path_config = os.path.join(path_config_folder, args.config)
    plot_losses_batch_exp(path_config)
    compute_save_scores_batch_exp(path_config, args.threshold, weighted=args.weighted, hyperbolic=args.hyperbolic)
    save_best_scores_batch_exp(path_config, weighted=args.weighted)
    plot_scores_batch_exp(path_config, weighted=args.weighted, hyperbolic=args.hyperbolic)