import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from sklearn import linear_model
from matplotlib import cm
from matplotlib.widgets import Slider
from os.path import isfile, join
from utils import utils
from mpl_toolkits import mplot3d
import os

def plot_image(title='My Title', data=None, save_path='./data.png', figsize=(20, 20), constrained_layout=False, cmap='gray', save=False):
    """Render ``data`` as an axis-free image and either save it or show it.

    Args:
        title: Currently not drawn on the figure (kept for API compatibility).
        data: Image passed straight to ``Axes.imshow``.
        save_path: Destination file used when ``save`` is true.
        figsize: Figure size forwarded to ``plt.figure``.
        constrained_layout: Forwarded to ``plt.figure``.
        cmap: Colormap forwarded to ``imshow``.
        save: When true, write the figure to ``save_path`` (tight bbox, no
            padding) and close it; otherwise call ``plt.show()``.
    """
    plt.figure(figsize=figsize, constrained_layout=constrained_layout)
    image_axes = plt.subplot(1, 1, 1)
    image_axes.imshow(data, cmap=cmap)
    plt.axis('off')

    if not save:
        plt.show()
        return

    plt.draw()
    plt.savefig(save_path, bbox_inches='tight', pad_inches=0)
    plt.close()

def concat_images(imga, imgb):
    """Place two images side by side on a zero-filled canvas.

    Both images are top-aligned. The shorter one is padded with zeros below
    it. Any trailing dimensions after (height, width), such as a colour-channel
    axis, are preserved. Grayscale ``(H, W)`` and channel-last ``(H, W, C)``
    images are therefore both supported.

    Args:
        imga: Left image, shape ``(ha, wa, *rest)``.
        imgb: Right image, shape ``(hb, wb, *rest)``; ``rest`` must match ``imga``.

    Returns:
        A float64 array of shape ``(max(ha, hb), wa + wb, *rest)``.

    Raises:
        ValueError: If the two images have different trailing dimensions.
    """
    imga = np.asarray(imga)
    imgb = np.asarray(imgb)

    ha, wa = imga.shape[:2]
    hb, wb = imgb.shape[:2]
    trailing = imga.shape[2:]
    if imgb.shape[2:] != trailing:
        raise ValueError(
            "images must share trailing dimensions, got {} and {}".format(imga.shape, imgb.shape))

    # np.zeros defaults to float64, matching the original canvas dtype.
    new_img = np.zeros(shape=(max(ha, hb), wa + wb) + trailing)
    new_img[:ha, :wa] = imga
    new_img[:hb, wa:wa + wb] = imgb
    return new_img
    
def check_plot_save(path, save, plot):
    """Optionally save the current figure, then show it or close it.

    Args:
        path: Output file; saving is skipped when this is ``None``.
        save: When truthy (and ``path`` is given), save with a tight bounding
            box and no padding.
        plot: When truthy, ``plt.show()`` is called; otherwise the current
            figure is closed.
    """
    wants_file = save and path is not None
    if wants_file:
        plt.savefig(path, bbox_inches='tight', pad_inches=0)

    if not plot:
        plt.close()
        return
    plt.show()

class Plotter:
    """Interactive scatter plot of multi-dimensional samples, two dimensions at a time.

    A horizontal and a vertical slider, plus the arrow keys, choose which data
    column is drawn on the x and y axes. Each ``plot_data`` call redraws the
    samples, optional per-sample text labels and optional centers for the
    current choice of dimensions.
    """

    def __init__(self, init_dim_x=0, init_dim_y=0):
        self.fig = None
        self.ax = None
        self.scat_samples = None    # artist returned by ax.scatter for the samples
        self.scat_centers = None    # container returned by ax.errorbar for the centers
        self.colors = None
        self.labels_samples = None  # list of per-sample annotations (empty when labels are off)
        self.input_dim = None       # number of columns in the most recent data (data.shape[-1])
        self.init_dim_x = init_dim_x  # column currently shown on the x axis
        self.init_dim_y = init_dim_y  # column currently shown on the y axis
        self.dimx = None            # x-dimension Slider
        self.dimy = None            # y-dimension Slider

    def update(self, val=None):
        """Slider callback: re-read the selected dimensions from both sliders.

        Matplotlib calls ``Slider.on_changed`` callbacks with the new slider
        value, so ``val`` must be accepted. It is ignored because both sliders
        are read back directly.
        """
        self.init_dim_x = int(self.dimx.val)
        self.init_dim_y = int(self.dimy.val)

    def press(self, event):
        """Key handler: arrow keys move the shown dimensions by one step.

        The result is clamped to ``[0, input_dim - 1]``. The corresponding
        slider is then moved and the canvas redraw is requested.
        """
        if event.key == 'right':
            self.init_dim_x += 1
        elif event.key == 'left':
            self.init_dim_x -= 1
        elif event.key == 'up':
            self.init_dim_y += 1
        elif event.key == 'down':
            self.init_dim_y -= 1

        last_dim = self.input_dim - 1
        self.init_dim_x = min(max(self.init_dim_x, 0), last_dim)
        self.init_dim_y = min(max(self.init_dim_y, 0), last_dim)

        # Moving the slider fires its on_changed callback (update), keeping both in sync.
        if event.key in ('right', 'left'):
            self.dimx.set_val(self.init_dim_x)
        elif event.key in ('up', 'down'):
            self.dimy.set_val(self.init_dim_y)

        self.fig.canvas.draw_idle()

    def plot_data(self, data, target, centers=None, relevances=None, pause_time=0.01, print_labels=False):
        """Draw (or redraw) samples and optional centers for the selected dimensions.

        Args:
            data: 2-D array indexed as ``data[:, dim]``; one row per sample.
            target: Per-sample numeric labels, used for colouring and, when
                ``print_labels`` is true, as annotation text.
            centers: Optional 2-D array of center coordinates (same columns as data).
            relevances: Optional array shaped like ``centers``, drawn as error bars.
                Centers are only drawn when both ``centers`` and ``relevances`` are given.
            pause_time: Timeout passed to ``plt.waitforbuttonpress`` after drawing.
            print_labels: Annotate every sample with its ``target`` value.
        """
        self.input_dim = data.shape[-1]

        first_call = self.fig is None
        if first_call:
            self._init_figure()

        self._draw_samples(data, target, print_labels)
        self._draw_centers(centers, relevances)

        if first_call:
            plt.show()
        plt.waitforbuttonpress(timeout=pause_time)

    def plot_hold(self, time=10000):
        """Keep the interactive figure alive for ``time`` seconds."""
        plt.pause(time)

    def _init_figure(self):
        """Create the figure, axes, dimension sliders and key binding (first plot only)."""
        plt.ion()
        self.fig, self.ax = plt.subplots(figsize=(10, 7))
        self.ax.grid(True)

        axdimx = plt.axes([0.2, 0.009, 0.65, 0.03], facecolor='lightgoldenrodyellow')
        axdimy = plt.axes([0.007, 0.17, 0.03, 0.65], facecolor='lightgoldenrodyellow')

        self.dimx = Slider(axdimx, 'Dim x', int(0), int(self.input_dim - 1), valfmt="%1.0f",
                           valinit=self.init_dim_x,
                           valstep=1, orientation='horizontal')
        self.dimy = Slider(axdimy, 'Dim y', int(0), int(self.input_dim - 1), valfmt="%1.0f",
                           valinit=self.init_dim_y,
                           valstep=1, orientation='vertical')

        self.dimx.on_changed(self.update)
        self.dimy.on_changed(self.update)
        self.fig.canvas.mpl_connect('key_press_event', self.press)

    def _draw_samples(self, data, target, print_labels):
        """Replace the sample scatter and labels with ones for the current dimensions."""
        self.ax.set_xlabel('Dim {}'.format(self.init_dim_x), fontsize=15)
        self.ax.set_ylabel('Dim {}'.format(self.init_dim_y), fontsize=15)

        xs = data[:, self.init_dim_x]
        ys = data[:, self.init_dim_y]
        # NOTE(review): yields NaN colours if target.max() == 0 — confirm targets are positive.
        colors = (2 * np.pi) * target / (target.max() * 2)

        if self.scat_samples is not None:
            self.scat_samples.remove()
        self.scat_samples = self.ax.scatter(xs, ys, c=colors, cmap='hsv', alpha=0.5)

        # Always drop the previous frame's labels so toggling print_labels
        # neither leaves stale text behind nor indexes into a missing list.
        for annotation in self.labels_samples or []:
            annotation.remove()
        self.labels_samples = []
        if print_labels:
            for label, x, y in zip(target, xs, ys):
                self.labels_samples.append(self.ax.annotate(label, (x, y),
                                                            horizontalalignment='center',
                                                            verticalalignment='center',
                                                            size=11))

    def _draw_centers(self, centers, relevances):
        """Replace the center error-bar plot; keep the previous one when no centers are given."""
        if centers is None or relevances is None:
            return
        if self.scat_centers is not None:
            self.scat_centers.remove()
        self.scat_centers = self.ax.errorbar(centers[:, self.init_dim_x], centers[:, self.init_dim_y],
                                             xerr=relevances[:, self.init_dim_x],
                                             yerr=relevances[:, self.init_dim_y],
                                             alpha=0.5, fmt='o', c='k')


class HParams:
    """Plot hyper-parameter values against result metrics, each with a linear trend line.

    Saved figures go under ``self.image_path`` (``"plots/"``), which is
    created on construction if missing.
    """

    def __init__(self):
        self.fig = None
        self.image_path = "plots/"
        # exist_ok avoids a crash if the directory is created between a check and mkdir.
        os.makedirs(self.image_path, exist_ok=True)

    def plot_hparams(self, tag, p_values, metrics, save=False, plot=False):
        """Plot ``metrics`` against ``p_values`` titled ``tag``; returns the figure."""
        return self.plot_x_y(p_values, metrics, tag, save=save, plot=plot)

    # ------------------------------------------------------------------- #
    # ---------------------- Parameters Graphs -------------------------- #

    def plot_params_results(self, file_name, header_rows=5, params_to_plot=None, save=False, plot=True):
        """Plot each parameter column of a results file against its matching result columns.

        The header, parameter table and result table are read with
        ``utils.read_header`` / ``utils.read_params_and_results``; the file
        format is defined there.

        Args:
            file_name: Results file to read.
            header_rows: Number of header rows forwarded to the ``utils`` readers.
            params_to_plot: Parameter column names to plot; defaults to all.
                Names containing ``"seed"`` are always skipped.
            save: Save each figure under ``image_path``.
            plot: Show each figure.

        With more than one fold, all result columns that match both the
        dataset and a metric are pooled into one plot per (parameter,
        dataset, metric). Otherwise each matching result column gets its own plot.
        """
        datasets, n_folds, _, metrics = utils.read_header([file_name], "", header_rows, save_parameters=False)
        params, results = utils.read_params_and_results(file_name, header_rows)

        if params_to_plot is None:
            params_to_plot = params.columns

        for param in params_to_plot:
            if "seed" in param:
                continue

            for dataset in datasets:
                matching = [result for result in results.columns if dataset in result]
                # np.float was removed in NumPy 1.24; builtin float gives the same float64 dtype.
                param_values = params[param].values.astype(float)

                if n_folds > 1:
                    for metric in metrics:
                        matching_metrics = [local_result for local_result in matching if metric in local_result]

                        # Pool the columns: x repeats the parameter values once per column,
                        # y concatenates the columns in order (column-major ravel).
                        x = np.tile(param_values, len(matching_metrics))
                        y = results[matching_metrics].values.astype(float).ravel(order='F')

                        self.plot_x_y(x, y, "{0} - {1}".format(param, dataset),
                                      save=save, plot=plot, limits=metric != "n_nodes")

                else:
                    for result in results[matching].columns:
                        y = results[result].values.astype(float)

                        self.plot_x_y(param_values, y, "{0} - {1}".format(param, result),
                                      save=save, plot=plot, limits="n_nodes" not in result)

    def plot_x_y(self, x, y, title, marker="o", color='b', font_size=12, save=False, plot=False, limits=True):
        """Scatter ``y`` against ``x`` with a red least-squares line.

        Args:
            x, y: Array-likes of equal length (lists, arrays or Series), converted to float.
            title: Plot title, also used as the file name ``<image_path>/<title>.png``.
            marker, color, font_size: Marker style, marker colour and title font size.
            save: Save the figure via ``check_plot_save``.
            plot: Show the figure (otherwise it is closed).
            limits: Fix the y axis to [0, 1] with ticks every 0.1.

        Returns:
            The created figure (also stored in ``self.fig``).
        """
        self.fig, ax = plt.subplots()
        ax.yaxis.grid()
        if limits:
            ax.set_ylim([0, 1])

        # asarray also accepts plain lists / Series (e.g. metric values passed as a list).
        x = np.asarray(x, dtype=float)
        y = np.asarray(y, dtype=float)

        # NOTE: plt.rc changes the process-wide default font, not just this figure.
        plt.rc('font', family='serif')
        plt.title(title, fontsize=font_size)
        plt.plot(x, y, marker, color=color, clip_on=False)
        if limits:
            plt.yticks(np.linspace(0, 1, num=11))

        self.plot_fit_linear(plt, x, y)

        plt.tight_layout(pad=0.2)
        check_plot_save(path=os.path.join(self.image_path, "{}.png".format(title)), save=save, plot=plot)

        return self.fig

    def plot_fit_linear(self, to_plot, x, y):
        """Fit ``y ~ x`` by ordinary least squares and draw the fitted line on ``to_plot``.

        Only pairs where both values are finite are used. Nothing is drawn
        when no such pair exists, because sklearn rejects empty or NaN input.
        """
        x = np.asarray(x, dtype=float).reshape(-1)
        y = np.asarray(y, dtype=float).reshape(-1)
        finite = np.isfinite(x) & np.isfinite(y)
        if not finite.any():
            return
        x, y = x[finite], y[finite]

        regr = linear_model.LinearRegression()
        regr.fit(x.reshape(-1, 1), y)
        fit = regr.predict(x.reshape(-1, 1))

        to_plot.plot(x, fit, color='r', clip_on=False, linewidth=6)

    def plot_tensorboard_x_y(self, parameters, metric_name, metric_values, writer, dataset, save=False, plot=False):
        """Plot every parameter column against ``metric_values`` and log each figure to ``writer``.

        Columns named ``'seed'`` or ``'Index'`` are skipped. Figures are tagged
        ``<param>/<metric_name>_<dataset>``.
        """
        # DataFrame.iteritems was removed in pandas 2.0; items() is the equivalent.
        for param, p_values in parameters.items():
            if param == 'seed' or param == 'Index':
                continue

            figure = self.plot_hparams(param, p_values.values, metric_values, save=save, plot=plot)
            writer.add_figure(param + '/' + metric_name + '_' + dataset, figure)


class PlotConfusionMatrix:
    """Write confusion matrices to disk as annotated heat-maps or as plain text."""

    @staticmethod
    def _eps_path(save_path):
        """Return ``save_path`` with its extension (if any) replaced by ``.eps``."""
        return os.path.splitext(save_path)[0] + '.eps'

    def save_cm(self, save_path, confusion_matrix, confusion_matrix_lines):
        """Save the first ``confusion_matrix_lines`` rows as an annotated heat-map.

        The image is written to ``save_path``, and an EPS copy goes next to it
        with the extension swapped to ``.eps``.

        Args:
            save_path: Output image path, e.g. ``"cm.png"``.
            confusion_matrix: Matrix supporting row slicing and ``.astype``;
                values are cast to int for the ``"d"`` annotation format.
            confusion_matrix_lines: Number of leading rows to keep.
        """
        confusion_matrix = confusion_matrix[0:confusion_matrix_lines]
        confusion_matrix = confusion_matrix.astype(int)

        fig = plt.figure(figsize=(50, 20))
        # NOTE: sns.set changes the global seaborn theme for all later plots too.
        sns.set(font_scale=2.5)

        ax = sns.heatmap(confusion_matrix, annot=True, fmt="d", cmap="Blues")
        plt.yticks(rotation=0)
        ax.xaxis.tick_top()
        ax.xaxis.set_label_position('top')

        plt.savefig(save_path)
        # splitext handles any extension length (the old [0:-3] slice mangled ".jpeg").
        plt.savefig(self._eps_path(save_path))
        # Close only this figure so other open figures (e.g. an interactive plot) survive.
        plt.close(fig)

    def save_txt(self, save_path, confusion_matrix):
        """Write the matrix as integers, one row per line, values separated by ``' , '``."""
        np.savetxt(save_path, confusion_matrix, fmt='%i', delimiter=' , ')

class PlotLossLandscape:
    """1-D, 2-D (contour) and 3-D (surface) plots of a loss landscape."""

    @staticmethod
    def _surface_grid(loss_data_fin):
        """Return integer index grids ``(X, Y)`` matching the first two dims of ``loss_data_fin``.

        ``X[i, j] == j`` (column index) and ``Y[i, j] == i`` (row index).
        """
        n_rows, n_cols = np.shape(loss_data_fin)[:2]
        X, Y = np.meshgrid(np.arange(n_cols), np.arange(n_rows))
        return X, Y

    def plot_loss_landscape_interpolation(self, loss_data, steps=40, title='Linear Interpolation of Loss', 
                    xlabel='Interpolation Coefficient', ylabel='Loss', plot=True, save_plots=False, folder="./",
                    save_path="./plot_loss_landscape_interpolation.png"):
        """Line plot of ``loss_data`` against coefficients ``i / steps`` for ``i`` in ``range(steps)``.

        ``loss_data`` must have ``steps`` entries. ``folder`` is currently unused
        and is kept for backward compatibility.
        """
        with sns.color_palette("husl", 8):
            plt.plot([1 / steps * i for i in range(steps)], loss_data)
            plt.title(title)
            plt.xlabel(xlabel)
            plt.ylabel(ylabel)
        self.apply_output_options(plot, save_plots, save_path)

    def plot_loss_landscape_2d(self, loss_data_fin, levels=50, title='Loss Contours around Trained Model',
                            plot=True, save_plots=False, save_path="./plot_loss_landscape_2d.png"):
        """Contour plot of the 2-D array ``loss_data_fin`` with ``levels`` contour levels."""
        plt.contour(loss_data_fin, levels=levels)
        plt.title(title)
        self.apply_output_options(plot, save_plots, save_path)

    def plot_loss_landscape_3d(self, loss_data_fin, projection='3d', steps=40, title='Surface Plot of Loss Landscape',
                                plot=True, save_plots=False, save_path="./plot_loss_landscape_3d.png"):
        """Surface plot of the 2-D array ``loss_data_fin`` over its row/column indices.

        The grid size is taken from ``loss_data_fin``'s shape. ``steps`` is kept
        for backward compatibility; previously it had to equal that shape.
        """
        fig = plt.figure()
        ax = plt.axes(projection=projection)
        X, Y = self._surface_grid(loss_data_fin)
        surf = ax.plot_surface(X, Y, loss_data_fin, cmap=cm.coolwarm, linewidth=0, antialiased=True)
        fig.colorbar(surf, shrink=0.5, aspect=5)
        ax.set_title(title)
        # Showing is left to apply_output_options so that plot=False really suppresses display.
        self.apply_output_options(plot, save_plots, save_path)

    def apply_output_options(self, plot, save_plots, save_path):
        """Save the current figure if requested, then show it (``plot``) or close it."""
        check_plot_save(path=save_path, save=save_plots, plot=plot)


if __name__ == '__main__':
    # Example run: plot the "at" hyper-parameter against every result column of one results file.
    results_file = "../results/fashion_r2/baseline_fashion_mnist_autoencoder30_r13_b256_0.csv"
    HParams().plot_params_results(results_file, save=False, plot=True, params_to_plot=["at"])
    # Alternative (disabled) entry point, kept for reference:
    # dataset_to_metrics("cm:pur:nmi:ce", "../exp_zero_loss_only_som/", "mnist", "../raw-datasets/",
    #                    "../mnist_xyz_results",
    #                    param_file="../arguments/autoencoder_50p.lhs")
