

import matplotlib.pyplot as plt
from matplotlib import colors
from matplotlib.ticker import FuncFormatter
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.datasets import load_diabetes, load_iris, load_breast_cancer, load_wine, load_boston, fetch_california_housing, fetch_kddcup99, fetch_openml

from utils.colors import get_color_list2

# Global plot theme shared by every function in this module.
# ``sns.set`` (an alias of ``sns.set_theme``) applies its own ``style``
# argument, whose default is "darkgrid"; a separate ``sns.set_style('white')``
# issued beforehand is therefore silently overridden, so the style is passed
# here directly.
sns.set(style='white',
        rc={'figure.dpi': 100, 'savefig.dpi': 100, 'figure.figsize': (3, 2),
            'axes.facecolor': 'white', 'figure.facecolor': 'white'},
        font_scale=1.1)


def change_width(ax, new_value):
    """Resize every patch on ``ax`` to ``new_value`` while keeping its centre fixed.

    Args:
        ax: Object exposing ``patches`` (e.g. a Matplotlib Axes holding bars).
        new_value: Target width for every patch.
    """
    for bar in ax.patches:
        shrink = bar.get_width() - new_value
        bar.set_width(new_value)
        # Move the left edge by half of the removed width so the bar stays centred.
        bar.set_x(bar.get_x() + shrink / 2)


def save_loss_plot(losses, plots_path):
    """Plot the training (and, if recorded, validation) loss curves and save them.

    The plot is drawn on a fresh figure that is closed after saving, so repeated
    calls neither draw on top of each other nor leave figures open.

    Args:
        losses (dict): Dictionary containing a list of training losses under
            ``"tloss_e"`` and, optionally, a list of validation losses under
            ``"vloss_e"``.
        plots_path (str): Directory in which ``loss.png`` is written.
    """
    train_loss = losses["tloss_e"]
    val_loss = losses.get("vloss_e", [])

    fig = plt.figure()
    plt.plot(list(range(len(train_loss))), train_loss, c='r', label="Training")
    title = "Training"
    if len(val_loss) > 0:
        # Validation may be recorded less often than training; stretch its
        # x positions by the ratio of the two lengths so both curves span the
        # same epoch range.
        beta = len(train_loss) / len(val_loss)
        val_x = [beta * i for i in range(len(val_loss))]
        plt.plot(val_x, val_loss, c='b', label="Validation")
        title += " and Validation"
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend(loc="upper right")
    plt.title(title + " Loss", fontsize=12)
    plt.tight_layout()
    plt.savefig(plots_path + "/loss.png")
    plt.close(fig)




def plot_global_feature(data, mask, labels, options, mode='train'):
    """Produce dataset-wide (averaged) feature-importance plots.

    For ``options["dataset"] == "mnist"`` this calls ``plot_grid_mnist_global``
    (whose result is returned) and then ``plot_grid_mnist_global2`` (whose result
    is discarded). Every other dataset is handled by ``plot_grid_other_global``,
    whose result is returned.
    """
    if options["dataset"] != "mnist":
        return plot_grid_other_global(data, mask, labels, options, mode=mode)

    importance = plot_grid_mnist_global(data, mask, labels, options, mode=mode)
    plot_grid_mnist_global2(data, mask, labels, options, mode=mode)
    return importance


def plot_grid(data, mask, labels, options, mode='train'):
    """Produce per-sample feature-importance plots, dispatching on the dataset.

    Returns:
        For non-MNIST datasets, the feature ranking returned by
        ``plot_grid_other``. For MNIST, ``None``: the MNIST plotters render
        images and produce no feature ranking.
    """
    if options["dataset"] != "mnist":
        return plot_grid_other(data, mask, labels, options, mode=mode)

    plot_grid_mnist(data, mask, labels, options, mode=mode)
    plot_grid_mnist2(data, mask, labels, options, mode=mode)
    # No ranking exists for image data; return an explicit value instead of an
    # unbound name.
    return None



def convert_to_arr(x):
    """Return ``x`` copied into a new NumPy array flattened to a single row of shape (1, N)."""
    row = np.array(x)
    return np.reshape(row, (1, -1))

    

def plot_grid_other(data, mask, labels, options, mode='train'):
    """Save a bar plot of the ten most important features of the first sample.

    Features are named ``f1`` ... ``fN`` (or taken from
    ``options["invase_features"]`` when ``options["dataset"] == "invase"``) and
    ranked by the importance weights in the first row of ``mask``.

    Args:
        data: 2-D array; only its first row is used, to count the features.
        mask: 2-D array of importance weights aligned with ``data``; only its
            first row is plotted.
        labels: Not used; kept so the ``plot_grid_*`` functions share a signature.
        options (dict): Reads ``"dataset"``, ``"framework"`` and ``"seed"``, plus
            ``"invase_features"`` for the invase dataset.
        mode (str): Prefix of the saved file name.

    Returns:
        list[str]: Names of up to ten features, in decreasing order of weight.
    """
    n_features = data[0, :].size
    weights = mask[0, :].reshape(-1,).tolist()

    feature_names = ["f" + str(i) for i in range(1, n_features + 1)]
    if options["dataset"] == "invase":
        feature_names = options["invase_features"]

    # Sort (weight, name) pairs together so weights and names cannot drift out
    # of alignment, then keep the top ten.
    ranked = sorted(zip(weights, feature_names), reverse=True)[:10]
    top_weights = [weight for weight, _ in ranked]
    top_names = [str(name) for _, name in ranked]

    df = pd.DataFrame({'Features': top_names,
                       'Importance weight': top_weights})

    # Colour each feature by its position in ``feature_names``; keys are str()-ed
    # to match the plotted names, and the colour list is cycled so datasets with
    # more features than colours can still be plotted.
    color_list = get_color_list2()
    custom_colors = {str(name): color_list[i % len(color_list)]
                     for i, name in enumerate(feature_names)}

    fig, ax = plt.subplots(figsize=(3, 3))
    ax = sns.barplot(y='Features', x='Importance weight', data=df,
                     palette=custom_colors, alpha=1.0, ax=ax)
    ax.set_xlabel("", fontsize=15)  # Importance weight
    ax.set_ylabel("Features", fontsize=15)
    for side in ('top', 'right', 'bottom', 'left'):
        ax.spines[side].set_visible(False)
    # Format ticks at draw time so labels always match the ticks actually drawn
    # (a FixedFormatter over an automatic locator can misalign them).
    ax.xaxis.set_major_formatter(
        FuncFormatter(lambda x, _pos: '{:,.2f}'.format(x) if x != 0 else '0'))
    ax.tick_params(axis="y", labelrotation=20, pad=-5, labelsize=13.5)

    fig.savefig(f"./results/{options['framework']}/evaluation/{mode}_grid_{options['seed']}.png",
                bbox_inches='tight')
    # Close (not just clear) the figure so repeated calls do not accumulate
    # open figures.
    plt.close(fig)

    return top_names
    


def plot_grid_other_global(data, mask, labels, options, mode='train'):
    """Save a bar plot of the ten features with the highest mean importance.

    Importance weights are averaged over the rows of ``mask``. Features are
    always named ``f1`` ... ``fN``; unlike ``plot_grid_other``, the invase
    feature names are not applied here.

    NOTE(review): the output path is the same ``{mode}_grid_{seed}.png`` used by
    ``plot_grid_other``, so whichever runs last overwrites the other. Confirm
    this is intended.

    Args:
        data: Array of samples (rows); its second dimension gives the number of
            features.
        mask: Array of importance weights aligned with ``data``; averaged over
            axis 0.
        labels: Not used; kept so the ``plot_grid_*`` functions share a signature.
        options (dict): Reads ``"framework"`` and ``"seed"`` for the output path.
        mode (str): Prefix of the saved file name.

    Returns:
        list[str]: Names of up to ten features, in decreasing order of mean weight.
    """
    n_features = np.shape(data)[1]
    weights = np.mean(mask, axis=0).tolist()

    feature_names = ['f' + str(i) for i in range(1, n_features + 1)]

    # Sort (weight, name) pairs together so weights and names cannot drift out
    # of alignment, then keep the top ten.
    ranked = sorted(zip(weights, feature_names), reverse=True)[:10]
    top_weights = [weight for weight, _ in ranked]
    top_names = [str(name) for _, name in ranked]

    df = pd.DataFrame({'Features': top_names,
                       'Importance weight': top_weights})

    print(top_names)

    # Colour each feature by its index (f1 -> first colour, ...), cycling the
    # colour list. This covers every feature that can be plotted, including
    # features beyond f10 and datasets with more features than colours.
    color_list = get_color_list2()
    custom_colors = {name: color_list[i % len(color_list)]
                     for i, name in enumerate(feature_names)}

    fig, ax = plt.subplots(figsize=(3, 3))
    ax = sns.barplot(y='Features', x='Importance weight', data=df,
                     palette=custom_colors, alpha=1.0, ax=ax)
    ax.set_xlabel("", fontsize=15)  # Importance weight
    ax.set_ylabel("Features", fontsize=15)
    for side in ('top', 'right', 'bottom', 'left'):
        ax.spines[side].set_visible(False)
    # Format ticks at draw time so labels always match the ticks actually drawn.
    ax.xaxis.set_major_formatter(
        FuncFormatter(lambda x, _pos: '{:,.2f}'.format(x) if x != 0 else '0'))
    ax.tick_params(axis="y", labelrotation=20, pad=-5, labelsize=13.5)

    fig.savefig(f"./results/{options['framework']}/evaluation/{mode}_grid_{options['seed']}.png",
                bbox_inches='tight')
    # Close (not just clear) the figure so repeated calls do not accumulate
    # open figures.
    plt.close(fig)

    return top_names





def plot_grid_mnist_global(data, mask, labels, options, mode='train'):
    """Save averaged image / mask / overlay heatmaps, overall and per class.

    One figure is written for the mean over all rows (file suffix ``"g"``).
    One more is written for each class ``k`` in ``range(options["n_classes"])``,
    with file suffix ``str(k)``. Each figure has three panels: the mean image,
    the mean mask, and the image with the mask blended in (normalised by its
    maximum).

    Args:
        data: Array whose rows, once averaged, reshape to 28x28.
        mask: Array aligned row-for-row with ``data``.
        labels: Per-row class labels, compared against ``0 .. n_classes - 1``.
        options (dict): Reads ``"n_classes"`` and ``"framework"``.
        mode (str): Prefix of the saved file names.

    Returns:
        None.
    """
    labels = np.array(labels)
    overlay_weight = 5

    # (file suffix, row selector): all rows first, then each class in order.
    selections = [("g", slice(None))]
    selections += [(str(k), np.where(labels == k)[0])
                   for k in range(options["n_classes"])]

    for suffix, rows in selections:
        image = np.mean(data[rows], axis=0).reshape(28, 28)
        weights = np.mean(mask[rows], axis=0).reshape(28, 28)

        # Blend the mask into the image, then normalise by the peak value.
        overlay = (1 - weights) * image + overlay_weight * weights
        overlay = overlay / np.max(overlay)

        fig, (ax1, ax2, ax3) = plt.subplots(1, 3,
                                            gridspec_kw={'width_ratios': [1, 1, 1]})
        # All three panels have the same 28x28 extent, so they line up without
        # explicit y-axis sharing (``get_shared_y_axes().join`` is deprecated in
        # recent Matplotlib releases).
        for ax, panel, show_cbar in ((ax1, image, False),
                                     (ax2, weights, False),
                                     (ax3, overlay, True)):
            sns.heatmap(panel, cmap="gist_gray_r", cbar=show_cbar, ax=ax,
                        xticklabels=[], yticklabels=[])
            ax.set_ylabel('')
            ax.set_xlabel('')
        ax2.set_yticks([])
        ax3.set_yticks([])

        fig.savefig(f"./results/{options['framework']}/evaluation/{mode}_{suffix}.png")
        # Close each figure so a call does not leave n_classes + 1 figures open.
        plt.close(fig)
        


def plot_grid_mnist(data, mask, labels, options, mode='train'):
    """Save image / mask / overlay heatmaps for the first MNIST sample.

    Args:
        data: Array whose first row reshapes to a 28x28 image.
        mask: Array aligned with ``data``; its first row reshapes to 28x28.
        labels: Not used; kept so the ``plot_grid_*`` functions share a signature.
        options (dict): Reads ``"framework"`` and ``"seed"`` for the output path.
        mode (str): Prefix of the saved file name.

    Returns:
        numpy.ndarray: The 28x28 overlay (image with the mask blended in with
        weight 10), divided by its maximum.
    """
    image = data[0, :].reshape(28, 28)
    weights = mask[0, :].reshape(28, 28)

    # Blend the mask into the image, then normalise by the peak value.
    overlay = (1 - weights) * image + 10 * weights
    overlay = overlay / np.max(overlay)

    fig, (ax1, ax2, ax3) = plt.subplots(1, 3,
                                        gridspec_kw={'width_ratios': [1, 1, 1]})
    # All three panels have the same 28x28 extent, so they line up without
    # explicit y-axis sharing (``get_shared_y_axes().join`` is deprecated in
    # recent Matplotlib releases).
    for ax, panel, show_cbar in ((ax1, image, False),
                                 (ax2, weights, False),
                                 (ax3, overlay, True)):
        sns.heatmap(panel, cmap="gist_gray_r", cbar=show_cbar, ax=ax,
                    xticklabels=[], yticklabels=[])
        ax.set_ylabel('')
        ax.set_xlabel('')
    ax2.set_yticks([])
    ax3.set_yticks([])

    fig.savefig(f"./results/{options['framework']}/evaluation/{mode}_grid_{options['seed']}.png")
    # Close (not just clear) the figure so repeated calls do not leak figures.
    plt.close(fig)

    return overlay




    


