import matplotlib
import matplotlib.pyplot as plt
# matplotlib.use('Agg')
import numpy as np

# Hex colour palette shared by the plotting helpers in this module.
blue, teal, dark_blue, red, pink = (
    '#0254e0',
    '#00d4b4',
    '#35619c',
    '#fa0f32',
    '#e305ba',
)

def plot_histories(histories, model_names, metric, save=None, validation=True, extracted=False, y_lims=(35_000, 46_000)):
    """Plot one metric per epoch (optionally with its validation twin) for several models.

    Args:
        histories: One history per model, in the same order as ``model_names``.
            Unless ``extracted`` is True, each entry is read through its
            ``.history`` attribute (presumably a Keras ``History`` object --
            confirm with callers); otherwise each entry is already a mapping
            from metric name to a per-epoch sequence.
        model_names: Legend prefix for each entry of ``histories``.
        metric: Key of the series to plot, e.g. ``'loss'``.
        save: Optional path passed to ``plt.savefig`` before the figure is shown.
        validation: When True, also plot the ``'val_' + metric`` series.
        extracted: Whether ``histories`` already hold the raw mappings.
        y_lims: ``(bottom, top)`` y-axis limits, or ``None`` to keep autoscaling.

    The figure is shown with ``plt.show()`` and then every open figure is closed.
    """
    plt.xlabel('Epochs')
    plt.ylabel(metric)

    for history, model_name in zip(histories, model_names):
        if not extracted:
            history = history.history

        # Attaching labels to the artists themselves keeps legend entries
        # aligned with their lines, even if the axes already held artists.
        plt.plot(history[metric], label=model_name + ' ' + metric)

        if validation:
            plt.plot(history['val_' + metric], label=model_name + ' val_' + metric)

    plt.legend()
    if y_lims is not None:
        plt.ylim(y_lims[0], y_lims[1])

    if save is not None:
        plt.savefig(save)

    plt.show()
    plt.close("all")


def show_image(image):
    """Draw ``image`` on the current axes with ``plt.imshow`` and display it.

    Args:
        image: Image data in any form ``matplotlib.pyplot.imshow`` accepts.

    Unlike the other plotting helpers in this module, the figure is not
    closed afterwards.
    """
    plt.imshow(image)  # renders onto the current figure/axes
    plt.show()


def _seed_mean_ci(histories, seeds, model_index, key, extracted):
    """Return ``(mean, ci)`` per epoch of series ``key`` for one model across all seeds.

    ``ci`` is the half-width of a normal-approximation 95% confidence interval
    of the mean: ``1.96 * std / sqrt(n_seeds)``, where ``std`` is NumPy's
    default population standard deviation (``ddof=0``).
    """
    runs = []
    for seed in seeds:
        history = histories[seed][model_index]
        if not extracted:
            history = history.history
        runs.append(history[key])

    runs = np.asarray(runs, dtype=float)  # rows: seeds, columns: epochs
    mean = runs.mean(axis=0)
    ci = runs.std(axis=0) * 1.96 / np.sqrt(runs.shape[0])
    return mean, ci


def plot_histories_seeds(histories, model_names, seeds, metric, save=None, validation=True, extracted=False, y_lims=(1.0, 1.15), legend_on=True):
    """Plot the across-seed mean of a metric per model, with a shaded 95% CI band.

    For example, with 2 models trained with 3 seeds each for 20 epochs, every
    model contributes a mean curve over its 3 runs (20 points), plus a shaded
    band of +/- the 95% confidence-interval half-width of that mean. All
    curves share one figure.

    Args:
        histories: Mapping ``{seed: [model0_history, model1_history, ...], ...}``,
            so the series for model ``i`` under seed ``s`` is
            ``histories[s][i][metric]`` (after unwrapping ``.history`` when
            ``extracted`` is False). Every run is assumed to have the same
            number of epochs.
        model_names: Legend prefix for each model, indexed like the per-seed lists.
        seeds: Sequence of keys into ``histories``; must be non-empty.
        metric: Key of the series to plot, e.g. ``'loss'``.
        save: Optional path passed to ``plt.savefig`` before the figure is shown.
        validation: When True, also plot the ``'val_' + metric`` series.
        extracted: Whether the per-model entries are already plain mappings
            (True) or objects exposing them via ``.history`` (False).
        y_lims: ``(bottom, top)`` y-axis limits, or ``None`` to keep autoscaling.
        legend_on: Whether to draw the legend.

    Raises:
        ValueError: If ``seeds`` is empty.

    Model 0 is drawn dashed in teal (train) / blue (validation), and model 1
    solid in red (train) / pink (validation); further models cycle through
    these styles. The figure is shown and then every open figure is closed.
    """
    if len(seeds) == 0:
        raise ValueError("seeds must contain at least one seed")

    plt.xlabel('Epochs')
    plt.ylabel(metric.capitalize())

    color_list = [[teal, blue], [red, pink]]
    linestyle = ['dashed', 'solid']

    for i, model_name in enumerate(model_names):
        train_color, val_color = color_list[i % len(color_list)]
        style = linestyle[i % len(linestyle)]

        # Each band uses the CI of its own series (train vs. validation).
        mean, ci = _seed_mean_ci(histories, seeds, i, metric, extracted)
        plt.plot(mean, color=train_color, linestyle=style, label=model_name + ' ' + metric)
        # Unlabelled fill_between artists are left out of the legend.
        plt.fill_between(range(len(mean)), mean - ci, mean + ci, color=train_color, alpha=0.1)

        if validation:
            val_mean, val_ci = _seed_mean_ci(histories, seeds, i, 'val_' + metric, extracted)
            plt.plot(val_mean, color=val_color, linestyle=style, label=model_name + ' val_' + metric)
            plt.fill_between(range(len(val_mean)), val_mean - val_ci, val_mean + val_ci, color=val_color, alpha=0.1)

    if legend_on:
        plt.legend()
    if y_lims is not None:
        plt.ylim(y_lims[0], y_lims[1])

    plt.tight_layout()

    if save is not None:
        plt.savefig(save)

    plt.show()
    plt.close("all")


