

import copy
import time

import mlflow
import numpy as np
import yaml

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import axes3d, Axes3D #<-- Note the capitalization! 

import numpy as np
import pandas as pd
import seaborn as sns

import eval
from src.model_clf import XTab as XTab_clf
from utils.arguments import get_arguments, get_config, print_config_summary
from utils.load_data import Loader
from utils.utils import set_dirs, run_with_profiler, update_config_with_model_dims
from sklearn.datasets import load_diabetes, load_iris, load_breast_cancer, load_wine, load_boston, fetch_california_housing, fetch_kddcup99, fetch_openml

# Global seaborn/matplotlib look for every figure produced by this script:
# 100 dpi on screen and on disk, 4x4 inch default size, white backgrounds.
_SNS_RC = {
    'figure.dpi': 100,
    'savefig.dpi': 100,
    'figure.figsize': (4, 4),
    'axes.facecolor': 'white',
    'figure.facecolor': 'white',
}
sns.set(rc=_SNS_RC, font_scale=1.1)
sns.set_style(style='white')





# Column names of the "income" dataset. train() uses this list as its default
# feature-name list when the dataset has no dedicated branch.
income_feature_types = (
    "workclass education marital-status occupation relationship race sex "
    "native-country age fnlwgt education-num capital-gain capital-loss "
    "hours-per-week"
).split()



def get_kvs(di, is_tuple=False, idx=0):
    """Return the values of ``di`` in insertion order (keys are ignored).

    Args:
        di (dict): Mapping whose values are collected.
        is_tuple (bool): If true, each value must be indexable and only
            ``value[idx]`` is kept.
        idx (int): Position taken from each value when ``is_tuple`` is true.

    Returns:
        list: One entry per value of ``di``.
    """
    if is_tuple:
        return [value[idx] for value in di.values()]
    return list(di.values())

# Colours appended after seaborn's "tab20" palette when colouring feature lines
# in the running-rank plot (train() repeats this list 4x so long feature lists
# still get a colour).
_EXTRA_COLORS = [
    '#78C850',  # Grass
    '#F08030',  # Fire
    '#6890F0',  # Water
    '#A8B820',  # Bug
    '#A8A878',  # Normal
    '#A040A0',  # Poison
    '#F8D030',  # Electric
    '#E0C068',  # Ground
    '#EE99AC',  # Fairy
    '#C03028',  # Fighting
    '#F85888',  # Psychic
    '#B8A038',  # Rock
    '#705898',  # Ghost
    '#98D8D8',  # Ice
    '#7038F8',  # Dragon
    '#00AFBB',
    '#C4961A',
    '#0000FF',
    '#FFDD44',
    '#00FFFF',
    '#008080',
    '#FF00FF',
    '#f78b2b',
    '#9ff72b',
    '#2b87f7',
    '#6b8eb8',
    '#a26bb8',
    '#d64ca8',
    '#c9d64c',
    '#9fa380',
    '#654957',
    '#ff5005',
    '#a33100',
]

# Datasets whose feature names come from a scikit-learn loader.
_SKLEARN_DATASETS = ("wine", "diabetes", "breast_cancer", "boston", "california")

# Feature order used to assign plot colours for the "income" dataset.
_INCOME_RANKED_FEATURES = [
    'marital-status', 'occupation', 'relationship', 'race', 'sex',
    'native-country', 'age', 'education-num', 'capital-gain', 'capital-loss',
    'hours-per-week', 'workclass', 'education', 'fnlwgt',
]

# Blog features (hard-coded lists of previously observed top rankings) that get
# a plot colour. The order only matters for features not in _BLOG_COLOR_OVERRIDES.
_BLOG_BASE_FEATURES = [
    'f21', 'f52', 'f4', 'f15', 'f11', 'f12', 'f14', 'f54',
    'f20', 'f188', 'f10', 'f21', 'f22', 'f29', 'f40', 'f60', 'f62', 'f64', 'f75',
    'f81', 'f84', 'f90', 'f100', 'f101', 'f106', 'f109', 'f128', 'f130',
    'f139', 'f154', 'f166', 'f169', 'f19', 'f55', 'f191', 'f200', 'f204', 'f206',
    'f210', 'f212', 'f245', 'f250',
]
_BLOG_EXTRA_FEATURES = [
    'f1', 'f2', 'f4', 'f5', 'f52', 'f11', 'f12', 'f14', 'f15',
    'f26', 'f27', 'f40', 'f46', 'f55', 'f62', 'f71', 'f51',
    'f86', 'f89', 'f101', 'f108', 'f124', 'f135', 'f138', 'f149', 'f156',
    'f161', 'f173', 'f180', 'f196', 'f213', 'f216', 'f223', 'f268', 'f271',
    'f275', 'f276', 'f277', 'f132', 'f220',
]
_BLOG_TOP_RANKED_RUN_A = [
    ['f51', 'f52', 'f52', 'f52', 'f52', 'f52', 'f52', 'f52', 'f52', 'f52'],
    ['f52', 'f51', 'f51', 'f51', 'f51', 'f55', 'f51', 'f51', 'f51', 'f51'],
    ['f55', 'f55', 'f55', 'f55', 'f54', 'f54', 'f55', 'f54', 'f54', 'f54'],
    ['f15', 'f15', 'f54', 'f54', 'f55', 'f51', 'f54', 'f55', 'f55', 'f55'],
    ['f54', 'f54', 'f15', 'f15', 'f15', 'f21', 'f21', 'f15', 'f15', 'f21'],
    ['f20', 'f21', 'f20', 'f20', 'f20', 'f15', 'f15', 'f21', 'f21', 'f15'],
    ['f6', 'f20', 'f21', 'f12', 'f21', 'f20', 'f20', 'f20', 'f20', 'f10'],
    ['f62', 'f11', 'f12', 'f21', 'f12', 'f10', 'f10', 'f10', 'f10', 'f20'],
    ['f2', 'f17', 'f11', 'f16', 'f10', 'f12', 'f12', 'f12', 'f12', 'f12'],
    ['f14', 'f62', 'f60', 'f11', 'f153', 'f60', 'f60', 'f14', 'f14', 'f11'],
]
_BLOG_TOP_RANKED_RUN_B = [
    ['f54', 'f54', 'f52', 'f52', 'f52', 'f52', 'f52', 'f52', 'f52', 'f52'],
    ['f52', 'f52', 'f51', 'f51', 'f51', 'f51', 'f51', 'f54', 'f54', 'f51'],
    ['f21', 'f10', 'f54', 'f54', 'f20', 'f54', 'f54', 'f51', 'f51', 'f54'],
    ['f117', 'f21', 'f10', 'f10', 'f54', 'f10', 'f10', 'f10', 'f10', 'f10'],
    ['f88', 'f11', 'f20', 'f20', 'f10', 'f20', 'f20', 'f55', 'f55', 'f55'],
    ['f3', 'f51', 'f11', 'f11', 'f55', 'f11', 'f55', 'f20', 'f20', 'f20'],
    ['f41', 'f20', 'f21', 'f55', 'f11', 'f55', 'f11', 'f11', 'f11', 'f11'],
    ['f45', 'f12', 'f12', 'f5', 'f21', 'f21', 'f5', 'f21', 'f21', 'f21'],
    ['f93', 'f5', 'f5', 'f21', 'f5', 'f5', 'f15', 'f5', 'f5', 'f15'],
    ['f60', 'f14', 'f101', 'f12', 'f25', 'f15', 'f12', 'f15', 'f15', 'f5'],
]
# De-duplicated in first-occurrence order. A set would make the order (and thus
# the colours of non-overridden features) depend on hash randomisation.
_BLOG_RANKED_FEATURES = list(dict.fromkeys(
    _BLOG_BASE_FEATURES
    + _BLOG_EXTRA_FEATURES
    + [feat for row in _BLOG_TOP_RANKED_RUN_A + _BLOG_TOP_RANKED_RUN_B for feat in row]
))

# Fixed colours for blog features, applied on top of the palette assignment.
_BLOG_COLOR_OVERRIDES = {
    'f15': '#1f77b4', 'f108': '#aec7e8', 'f149': '#ff7f0e', 'f51': '#ffbb78',
    'f27': '#2ca02c', 'f210': '#98df8a', 'f20': '#d62728', 'f86': '#ff9896',
    'f154': '#9467bd', 'f213': '#c5b0d5', 'f54': '#8c564b', 'f268': '#c49c94',
    'f245': '#e377c2', 'f138': '#f7b6d2', 'f132': '#7f7f7f', 'f139': '#c7c7c7',
    'f62': '#bcbd22', 'f46': '#dbdb8d', 'f223': '#17becf', 'f220': '#9edae5',
    'f101': '#78C850', 'f166': '#F08030', 'f14': '#6890F0', 'f250': '#A8B820',
    'f212': '#A8A878', 'f5': '#A040A0', 'f271': '#F8D030', 'f55': '#E0C068',
    'f90': '#EE99AC', 'f204': '#C03028', 'f12': '#F85888', 'f81': '#B8A038',
    'f156': '#705898', 'f130': '#98D8D8', 'f180': '#7038F8', 'f128': '#00AFBB',
    'f19': '#C4961A', 'f124': '#0000FF', 'f84': '#FFDD44', 'f135': '#00FFFF',
    'f52': '#008080', 'f29': '#FF00FF', 'f26': '#f78b2b', 'f173': '#9ff72b',
    'f216': '#2b87f7', 'f169': '#6b8eb8', 'f2': '#a26bb8', 'f71': '#d64ca8',
    'f206': '#c9d64c', 'f106': '#9fa380', 'f109': '#654957', 'f275': '#ff5005',
    'f22': '#a33100', 'f277': '#78C850', 'f40': '#F08030', 'f100': '#6890F0',
    'f191': '#A8B820', 'f200': '#A8A878', 'f10': '#A040A0', 'f64': '#F8D030',
    'f21': '#E0C068', 'f1': '#EE99AC', 'f89': '#C03028', 'f188': '#F85888',
    'f75': '#B8A038', 'f276': '#705898', 'f60': '#98D8D8', 'f11': '#7038F8',
    'f196': '#00AFBB', 'f161': '#C4961A', 'f4': '#0000FF', 'f262': '#F89988',
}


def _load_results_array(results_path, dataset, suffix):
    """Load ``<results_path>/<dataset><suffix>`` with NumPy and flatten it to 1-D."""
    # allow_pickle executes pickled code: only point this at result files you
    # produced yourself (presumably written by this project's training run).
    return np.load(results_path + "/" + dataset + suffix, allow_pickle=True).ravel()


def _feature_names(config_o):
    """Return the feature names used to label per-fold rankings of ``config_o["dataset"]``.

    Only the selected dataset's loader is called (lazily), so picking e.g. "wine"
    no longer also loads/downloads every other sklearn dataset, and OpenML is
    only queried when its feature names are actually used ("electricity").
    Datasets without a dedicated branch fall back to ``income_feature_types``.
    """
    dataset = config_o["dataset"]
    sklearn_loaders = {
        "wine": load_wine,
        "diabetes": load_diabetes,
        "breast_cancer": load_breast_cancer,
        "boston": load_boston,
        "california": fetch_california_housing,
    }
    if dataset.startswith("invase"):
        return ["f" + str(i) for i in range(1, 21)]
    if dataset == "blog":
        return ["f" + str(i) for i in range(1, 281)]
    if dataset.startswith("l2x"):
        return ["f" + str(i) for i in range(1, 11)]
    if dataset == "mushroom":
        return config_o["mushroom_features"]
    if dataset == "credit":
        return config_o["credit_features"]
    if dataset == "electricity":
        return fetch_openml(data_id=151).feature_names
    if dataset in sklearn_loaders:
        return sklearn_loaders[dataset]().feature_names
    if dataset == "mnist":
        # NOTE(review): this yields "1".."783" (783 names); MNIST images usually
        # have 784 pixels — confirm the range is intended.
        return [str(i) for i in range(1, 784)]
    return income_feature_types


def _fold_ranks(feature_names, fold_feature_lists):
    """Map each feature name to its 1-based position in every fold's ranked list.

    A fold that does not contain the feature contributes ``None``. Only the first
    occurrence in a fold is counted, so every value has exactly one entry per fold
    and the lists stay aligned when turned into a DataFrame.
    """
    ranks = {}
    for name in feature_names:
        ranks[name] = [
            next((pos for pos, feat in enumerate(fold, start=1) if feat == name), None)
            for fold in fold_feature_lists
        ]
    return ranks


def _average_state_dicts(state_dicts):
    """Return a copy of ``state_dicts[0]`` whose entries are the mean over ``state_dicts``.

    The inputs are not modified, so the loaded fold masks can be reused for every
    prefix length instead of being re-read from disk each time.
    """
    first = state_dicts[0]
    averaged = copy.copy(first)  # shallow copy keeps the mapping's type
    for key in first:
        averaged[key] = sum(sd[key] for sd in state_dicts) / len(state_dicts)
    return averaged


def train(config_o, data_loader, save_weights=True, seed=None):
    """Analyse saved cross-validation fold masks and plot how feature rankings evolve.

    Loads per-fold mask state dicts, accuracies and test feature rankings from the
    results directory. It then saves accuracy stats and two per-fold ranking
    plots. For every prefix of folds (1..n), it averages the fold masks into a
    global mask, evaluates it with ``eval.main`` and plots the running top-10
    feature ranking.

    Args:
        config_o (dict): Options and arguments; deep-copied before per-step edits.
        data_loader: Unused (this function builds its own loaders); kept for
            interface compatibility.
        save_weights (bool): Unused; kept for interface compatibility.
        seed (int, optional): Stored in the evaluation config and used in plot
            file names.

    Returns:
        tuple: ``(mg_features, mf_features, instance_features)`` computed from the
        mask averaged over all folds. The first two are the first 10 entries of
        ``eval.main`` result elements 1 and -1, each followed by result
        element 2 (collected as an accuracy). The third is the first 10 entries
        of result element -2.

    Raises:
        ValueError: If no fold mask state dicts were saved.
    """
    dataset = config_o["dataset"]
    base_model = XTab_clf(config_o)
    results_path = base_model._results_path

    mask_state_dicts_l = _load_results_array(results_path, dataset, '_mask_state_dicts_l.npy')
    acc_auc_dicts = _load_results_array(results_path, dataset, '_acc_auc_dicts.npy')[0]
    feature_te_dicts = _load_results_array(results_path, dataset, '_feature_te_dicts.npy')[0]

    print(mask_state_dicts_l)
    print(acc_auc_dicts)
    print(feature_te_dicts)

    n_masks = len(mask_state_dicts_l)
    if n_masks == 0:
        raise ValueError(f"No fold mask state dicts found for dataset {dataset!r} under {results_path}")

    accs = get_kvs(acc_auc_dicts, is_tuple=True, idx=0)
    features = get_kvs(feature_te_dicts)

    # Diagnostic only: per fold, feature -> 1-based positions where it appears.
    ranking = {}
    for fold_num, fold_features in enumerate(features, start=1):
        positions = {}
        for pos, feat in enumerate(fold_features, start=1):
            positions.setdefault(feat, []).append(pos)
        ranking[fold_num] = positions
    print(ranking)

    feature_types = _feature_names(config_o)

    rankings2 = _fold_ranks(feature_types, features)
    print(rankings2)
    # Drop features whose rank list is entirely None (one None per accuracy entry).
    all_missing = len(accs) * [None]
    rankings2 = {feat: ranks for feat, ranks in rankings2.items() if ranks != all_missing}
    print(rankings2)

    acc_pd = pd.DataFrame({'mean': np.mean(accs), 'stdev': np.std(accs)}, index=[0])
    acc_pd.to_csv(results_path + '/acc_stats.csv')

    fold_index = list(range(1, n_masks + 1))
    ranking_df = pd.DataFrame(rankings2, index=fold_index)
    accs_df = pd.DataFrame(accs, index=fold_index)

    plot_prefix = results_path + "/" + dataset
    sns.relplot(data=ranking_df, kind='line')
    plt.savefig(plot_prefix + "_test2_seed" + str(seed) + ".png")

    _, (ax1, ax2) = plt.subplots(2, 1)
    sns.lineplot(data=accs_df, ax=ax2)
    sns.lineplot(data=ranking_df, ax=ax1)
    # Put the legend out of the figure.
    plt.legend(bbox_to_anchor=(1.01, 1), borderaxespad=0)
    plt.tight_layout()
    ax1.set_ylabel('Ranking', color='g')
    ax2.set_ylabel('Val. Accuracy', color='b')
    plt.savefig(plot_prefix + "_test_seed" + str(seed) + ".png")

    # top_ranked[r][s]: feature ranked r+1 by the mask averaged over the first s+1 folds.
    top_ranked = [[] for _ in range(10)]

    for jj in range(1, n_masks + 1):
        averaged_mask = _average_state_dicts(mask_state_dicts_l[:jj])

        # Build a model whose global mask is replaced by the averaged mask.
        config = copy.deepcopy(config_o)
        config["fold_num"] = 0
        config["validate"] = False
        config["training_data_ratio"] = 1.0
        config["test_mode"] = True
        step_loader = Loader(config, dataset_name=config["dataset"])
        config = update_config_with_model_dims(step_loader, config)
        # Re-applied after update_config_with_model_dims, as before.
        config["validate"] = False
        config["training_data_ratio"] = 1.0
        # Mask is not trained here; the global mask is only injected below.
        config["train_mask"] = False
        config["use_mask_g"] = False
        config["seed"] = seed

        model = XTab_clf(config)
        model.mask_g.load_state_dict(averaged_mask)

        # Evaluate with the averaged global mask (no noise, no validation split).
        config = copy.deepcopy(config_o)
        config["add_noise"] = False
        config["validate"] = False
        config["training_data_ratio"] = 1.0
        config["train_mask"] = False
        config["use_mask_g"] = True
        config["fold_num"] = 0
        config["test_mode"] = True

        feature_importance_tuple = eval.main(config, mask_g=model.mask_g)
        feature_importance_te = feature_importance_tuple[1]
        feature_importance_te_mf = feature_importance_tuple[-1]
        instancewise_ranking_te = feature_importance_tuple[-2]

        for rank, bucket in enumerate(top_ranked):
            bucket.append(feature_importance_te[rank])

    # Results of the last step, i.e. the mask averaged over all folds.
    mg_features = [feature_importance_te[i] for i in range(10)] + [feature_importance_tuple[2]]
    mf_features = [feature_importance_te_mf[i] for i in range(10)] + [feature_importance_tuple[2]]
    instance_features = [instancewise_ranking_te[i] for i in range(10)]

    # Feature -> rank (1..10) at each averaging step; None when outside the top 10.
    ranked_features_dict = {}
    for rank, features_at_rank in enumerate(top_ranked, start=1):
        for step, feat in enumerate(features_at_rank):
            ranked_features_dict.setdefault(feat, n_masks * [None])[step] = rank
    print(ranked_features_dict)

    ranked_features_df = pd.DataFrame(ranked_features_dict, index=fold_index)
    print(ranked_features_df)
    print(top_ranked)

    if dataset in ("blog", "mushroom", "breast_cancer") or dataset.startswith("invase"):
        ranked_features_df = ranked_features_df.iloc[:, :10]

    # Assign one palette colour per feature, in the order of `color_features`.
    palette = list(sns.color_palette("tab20").as_hex()) + 4 * _EXTRA_COLORS
    if dataset == "income":
        color_features = _INCOME_RANKED_FEATURES
    elif dataset == "blog":
        color_features = _BLOG_RANKED_FEATURES
    elif dataset.startswith("invase"):
        color_features = ["f" + str(i) for i in range(1, 21)]
    elif dataset.startswith("l2x"):
        color_features = ["f" + str(i) for i in range(1, 11)]
    elif dataset in _SKLEARN_DATASETS or dataset == "mushroom":
        color_features = feature_types
    else:
        color_features = ranked_features_df.columns
    custom_colors = {feat: palette[i] for i, feat in enumerate(color_features)}
    if dataset in _SKLEARN_DATASETS:
        custom_colors['worst texture'] = '#d62728'
    print(custom_colors)
    if dataset == "blog":
        custom_colors.update(_BLOG_COLOR_OVERRIDES)

    sns.set(rc={'figure.dpi': 100, 'savefig.dpi': 100, 'figure.figsize': (4, 4), 'axes.facecolor': 'white', 'figure.facecolor': 'white'}, font_scale=1.1)
    sns.set_style(style='white')

    fig_aspect = 1. if dataset in ("wine", "breast_cancer", "mushroom") else 1.05
    seab = sns.relplot(data=ranked_features_df, kind='line', palette=custom_colors, dashes=False, height=4, aspect=fig_aspect, markers=False)
    seab.set_xlabels("Global Mask as CA", fontsize=15)
    seab.set_ylabels("Ranking", fontsize=15)
    seab._legend.remove()

    plt.ylim(11, 0)  # inverted so that rank 1 is at the top
    plt.xticks(fold_index, fontweight='bold', fontsize=15)
    plt.yticks(list(range(1, 11)), fontweight='bold', fontsize=15)
    plot_legends = plt.legend(title="Features", bbox_to_anchor=(1, 1), loc='upper left', borderaxespad=0.3, fontsize=13.5, title_fontsize=15, prop={'weight': 'bold'})
    for line in plot_legends.get_lines():
        line.set_linewidth(5)

    # Steps 1..n-1 are labelled g_i; the last step (all folds averaged) is M_g.
    xlabels = [r"$\mathbf{g_{" + str(i) + "}}$" for i in range(1, n_masks)] + [r"$\mathbf{M_{g}}$"]
    seab.ax.set_xticklabels(xlabels, fontsize=15)
    plt.setp(seab.ax.lines, linewidth=5)

    plt.tight_layout()
    plt.savefig(model._results_path + "/" + dataset + "_running_rank_seed" + str(seed) + "_global.png")
    # Release every figure created above; otherwise they accumulate across the
    # repeated per-seed calls.
    plt.close("all")

    return mg_features, mf_features, instance_features




def main(config, seed=None):
    """Run the analysis for a single configuration/seed.

    Creates the result directories, builds a data loader and delegates to
    ``train``.

    Args:
        config (dict): Options and arguments for this run.
        seed (int, optional): Forwarded to ``train``.

    Returns:
        tuple: ``(mg_features, mf_features, instance_features)`` from ``train``.
    """
    set_dirs(config)
    # train() never reads this loader; it is still built here in case the
    # Loader constructor has side effects that are not visible from this file.
    loader = Loader(config, dataset_name=config["dataset"])
    outcome = train(config, loader, save_weights=True, seed=seed)
    mg_features, mf_features, instance_features = outcome
    return mg_features, mf_features, instance_features

def architecture_name(config):
    """Return the results-folder infix for the first enabled activation flag.

    Flags are checked in the order linear, relu, relux2, relux3, relux5.

    Raises:
        ValueError: If none of the flags is enabled. Previously this left
            ``arc_name`` unbound and crashed with a NameError.
    """
    for flag, name in (
        ("linear", "_linear_seed_"),
        ("relu", "_linear_relu_seed_"),
        ("relux2", "_linear_relux2_seed_"),
        ("relux3", "_linear_relux3_seed_"),
        ("relux5", "_linear_relux5_seed_"),
    ):
        if config[flag]:
            return name
    raise ValueError("No architecture flag (linear/relu/relux2/relux3/relux5) is enabled in config")


if __name__ == "__main__":
    # Get parser / command line arguments
    args = get_arguments()
    seeds = [211, 369, 317, 79, 54, 654, 34, 167, 4468, 7817]  # 770
    mg_dict = {}
    mf_dict = {}
    instance_dict = {}

    for run_seed in seeds:
        # Get configuration file
        config = get_config(args)
        config["linear"] = False
        config["relu"] = True
        config["relux2"] = False
        config["relux3"] = False
        config["relux5"] = False
        config["c_hdim"] = 1024
        config["add_noise"] = False
        # Type of noise to add. Choices: swap_noise, gaussian_noise, zero_out
        config["noise_type"] = "gaussian_noise"
        config["noise_level"] = 0.15
        # Percentage of the features to add noise to
        config["masking_ratio"] = 0.3
        config["n_subsets"] = 2
        config["overlap"] = 0.75

        arc_name = architecture_name(config)

        # Overwrite the parent folder name for saving results
        config["framework"] = (
            config["dataset"] + arc_name + "class_" + str(config["c_hdim"])
            + "_nv_" + str(run_seed)
            + "_gaussian" + str(config["noise_level"]) + str(config["masking_ratio"])
            + "_" + str(config["n_subsets"]) + str(config["overlap"])
        )

        # Summarize config and arguments on the screen as a sanity check
        print_config_summary(config, args)
        mg_features, mf_features, instance_features = main(config, seed=run_seed)

        mg_dict[str(run_seed)] = mg_features
        mf_dict[str(run_seed)] = mf_features
        instance_dict[str(run_seed)] = instance_features

    # One column per seed; wrapping in pd.Series lets lists of different lengths align (NaN-padded).
    mg_ranking_df = pd.DataFrame({k: pd.Series(v) for k, v in mg_dict.items()})
    mf_ranking_df = pd.DataFrame({k: pd.Series(v) for k, v in mf_dict.items()})
    instance_ranking_df = pd.DataFrame({k: pd.Series(v) for k, v in instance_dict.items()})

    # Save dataframes as csv files for later use.
    # NOTE(review): config["framework"] embeds only the LAST seed (7817), so these
    # cross-seed CSVs are named after that run — confirm this is intended.
    csv_prefix = "./results/csvs/" + config["framework"]
    mg_ranking_df.to_csv(csv_prefix + "_mg.csv")
    mf_ranking_df.to_csv(csv_prefix + "_mf.csv")
    instance_ranking_df.to_csv(csv_prefix + "_instance.csv")
