#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import sys, os
import random
from matplotlib.axes import Axes
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

from typing import List
from scipy.optimize import fsolve
from matplotlib.patches import Ellipse, Patch

sys.path.append(os.pardir)
from seds.gmm_fitting import diff, find_limits


class PlotConfigs:
    """Shared, hard-coded styling constants used by the plotting helpers.

    All values are plain class attributes; nothing is computed at runtime.
    """

    # Figure geometry: size in inches and resolution in dots per inch.
    FIGURE_SIZE = (12, 8)
    FIGURE_DPI = 120

    # Font sizes, in points.
    TICKS_SIZE = 16
    LABEL_SIZE = 18
    LEGEND_SIZE = 18
    TITLE_SIZE = 18

    # Per-curve palette, indexed by curve position.
    COLORS = [
        "darkblue", "darkorange", "darkgreen",
        "purple", "salmon", "brown",
        "tan", "black", "cyan",
    ]
    # Per-curve matplotlib format strings (marker + line style), indexed by curve position.
    FMTS = ["d--", "o-", "s:", "x-.", "*-", "d--", "o-"]


def plot_trajectory(trajectory: np.ndarray, title: str = "Reference Trajectories"):
    """Scatter-plot the first two columns of a trajectory, save the figure, then show it.

    The figure is written to a file called ``trajectory`` (no extension given, so
    matplotlib's default save format applies) in the current working directory.

    Args:
        trajectory (np.ndarray): Array whose columns 0 and 1 are drawn as X1 and X2.
        title (str, optional): Figure title. Defaults to "Reference Trajectories".
    """
    plt.figure(figsize=PlotConfigs.FIGURE_SIZE, dpi=PlotConfigs.FIGURE_DPI)

    plt.scatter(trajectory[:, 0], trajectory[:, 1], marker='o', s=3)

    label_size = PlotConfigs.LABEL_SIZE
    plt.xlabel("X1", fontsize=label_size)
    plt.ylabel("X2", fontsize=label_size)
    plt.title(title, fontsize=PlotConfigs.TITLE_SIZE)
    plt.grid()

    plt.savefig('trajectory', dpi=PlotConfigs.FIGURE_DPI, bbox_inches='tight')
    plt.show()


def plot_bic_scores(bics: List, n_components: List):
    """Plot BIC scores of GMM fits together with their first and second differences.

    Args:
        bics (List): BIC score of each fitted model.
        n_components (List): Number of Gaussian functions of each fitted model. It is
            used as the x-values of every curve, so it must be a sequence with the
            same length as ``bics`` (a single int will not work).
    """
    # Prepend 0 so each difference series lines up with ``n_components``
    # (presumably ``diff`` returns one element fewer than its input). Unpacking
    # builds a list whether ``diff`` returns a list or a NumPy array; with an
    # array, ``[0] + arr`` would broadcast-add instead of prepending.
    diff1 = [0, *diff(bics)]
    diff2 = [0, *diff(diff1)]

    _, axs = plt.subplots(2)
    axs[0].plot(n_components, bics, label="BIC")
    axs[0].set_title("BIC Score for GMM fit")
    axs[0].grid()
    axs[0].set_xlabel("Number of Gaussian Functions")
    axs[0].legend()

    axs[1].plot(n_components, diff1, label="diff1(BIC)")
    axs[1].plot(n_components, diff2, label="diff2(BIC)")
    axs[1].set_title("First and Second Derivative for GMM fit")
    axs[1].grid()
    axs[1].set_xlabel("Number of Gaussian Functions")
    axs[1].legend()

    plt.tight_layout()
    plt.show()


def _covariance_ellipse(mean, covariance) -> Ellipse:
    """Return the ellipse patch for one Gaussian component (first two principal axes)."""
    eigenvalues, eigenvectors = np.linalg.eigh(covariance)
    # Full axis lengths: 2 * sqrt(2) * the standard deviation along each eigenvector.
    axis_lengths = 2. * np.sqrt(2.) * np.sqrt(eigenvalues)
    # ``eigh`` returns the eigenvectors as COLUMNS, and the ellipse width
    # (axis_lengths[0]) runs along eigenvectors[:, 0]. Row 0 only coincides with that
    # column when the eigenvector matrix happens to be symmetric. ``arctan2`` also
    # handles a vertical axis without dividing by zero. An ellipse is symmetric
    # under a 180 degree rotation, so no extra offset is needed.
    width_axis = eigenvectors[:, 0]
    angle = np.degrees(np.arctan2(width_axis[1], width_axis[0]))
    return Ellipse(mean, axis_lengths[0], axis_lengths[1], angle=angle)


def plot_gmm(trajectory: np.ndarray, means: List, covariances: List):
    """ This function plots the covariance and mean of the components of the GMM on the reference
        trajectory.

    Example:
        plot_gmm(trajectory=positions_py, means=gmm_sine.means_,
                covariances=gmm_sine.covariances_)

    Args:
        trajectory (np.ndarray): The reference trajectory; columns 0 and 1 are plotted.
        means (List): List of mean parameters for Gaussian models.
        covariances (List): List of (full) covariance matrices for Gaussian models, one
            per entry of ``means``.
    """

    # generate one ellipse per GMM component
    ellipses = [_covariance_ellipse(mean, covariance)
                for mean, covariance in zip(means, covariances)]

    # plot the trajectory
    _, ax = plt.subplots(figsize=PlotConfigs.FIGURE_SIZE, dpi=PlotConfigs.FIGURE_DPI)
    plt.scatter(trajectory[:, 0], trajectory[:, 1], marker='o', s=5)

    # plot the means
    for mean in means:
        plt.plot([mean[0]], [mean[1]], marker='x', markersize=8, color='red')

    # plot the ellipses with a random, translucent fill color
    for ell in ellipses:
        ax.add_artist(ell)
        ell.set_clip_box(ax.bbox)
        ell.set_alpha(0.6)
        ell.set_facecolor(np.random.rand(3))

    x_min, x_max, y_min, y_max = find_limits(trajectory)
    # Pad each limit outward by 10% of its magnitude. A plain ``* 0.9`` / ``* 1.1``
    # would move negative limits inward and clip the data.
    ax.set_xlim(x_min - 0.1 * abs(x_min), x_max + 0.1 * abs(x_max))
    ax.set_ylim(y_min - 0.1 * abs(y_min), y_max + 0.1 * abs(y_max))

    plt.grid()
    plt.xlabel('X1', fontsize=16)
    plt.ylabel('X2', fontsize=16)
    plt.show()


def _draw_policy_arrows(points: np.ndarray, velocities: np.ndarray, scale_factor: float,
                        width: float, color: str):
    """Draw an arrow of length ``1 / scale_factor`` at each point, along its velocity."""
    for point, velocity in zip(points, velocities):
        norm = np.linalg.norm(velocity)
        # A zero (or non-finite) velocity has no direction; normalizing it would
        # produce NaN arrow components, so the point is skipped.
        if norm == 0 or not np.isfinite(norm):
            continue
        scale = scale_factor * norm
        plt.arrow(point[0], point[1], velocity[0] / scale, velocity[1] / scale,
                  width=width, color=color)


def plot_ds(ds, trajectory: np.ndarray, title: str = None, scale_factor: float = 1.2,
    space_stretch: float = 5, width: float = 0.15, frequency: int = 15,
    show_grid_arrows: bool = True, file_name: str = "",
    show_data_arrows: bool = True, save_dir: str = ""):
    """ Plot a dynamical system and its vector maps.

    TODO: Mark the start and end of trajectories!

    Args:
        ds (PlanningPolicyInterface): A dynamical system for motion generation task. Its
            ``predict`` method is called with an array of positions (one per row) and is
            expected to return one velocity per row.
        trajectory (np.ndarray): Input trajectory array (n_samples, dim); only the first
            two columns are plotted.
        title (str, optional): Title of the plot. Defaults to None.
        scale_factor (float, optional): Arrows are normalized to length
            ``1 / scale_factor`` on the grid and ``2 / scale_factor`` on the data.
            Defaults to 1.2.
        space_stretch (float, optional): Margin added on every side of the trajectory's
            bounding box, both for the view and for the arrow grid. Defaults to 5.
        width (float, optional): Shaft width of each arrow. Defaults to 0.15.
        frequency (int, optional): Grid samples per axis (``frequency ** 2`` grid arrows);
            ``2 * frequency`` trajectory points are sampled for the data arrows. Increase
            to yield more arrows in the plot. Defaults to 15.
        show_grid_arrows (bool, optional): Whether to show the grid arrows. Defaults to True.
        file_name(str, optional): Name of the plot file. Defaults to "" (saved as "plot").
        show_data_arrows (bool, optional): Whether to show arrows for data points. Defaults to True.
        save_dir(str, optional): Provide a save directory for the figure. Leave empty to
            skip saving and show the figure instead. Defaults to "".
    """

    # find trajectory limits and the stretched plotting window
    x_min, x_max, y_min, y_max = find_limits(trajectory)
    x_lo, x_hi = x_min - space_stretch, x_max + space_stretch
    y_lo, y_hi = y_min - space_stretch, y_max + space_stretch

    # calibrate the axis
    fig = plt.figure(figsize=PlotConfigs.FIGURE_SIZE, dpi=PlotConfigs.FIGURE_DPI)
    axes = plt.gca()
    axes.set_xlim([x_lo, x_hi])
    axes.set_ylim([y_lo, y_hi])

    # Scatter a resample (with replacement, same size) of the demonstrations.
    resampled = np.array(random.choices(trajectory, k=len(trajectory)))
    plt.scatter(resampled[:, 0], resampled[:, 1], color='blue', marker='o', s=5,
                label='Expert Demonstrations')

    legend_handles = []

    # policy actions on a regular grid covering the whole window
    if show_grid_arrows:
        x_interval = np.linspace(x_lo, x_hi, frequency)
        y_interval = np.linspace(y_lo, y_hi, frequency)
        grid_coordinates = np.array([[x, y] for x in x_interval for y in y_interval])
        _draw_policy_arrows(grid_coordinates, ds.predict(grid_coordinates),
                            scale_factor, width, 'green')
        legend_handles.append(plt.Line2D([0], [0], color='green', linestyle='-',
            label='Policy Action (Unknown Region)'))

    # policy actions on points sampled from the demonstrations (arrows twice as long)
    if show_data_arrows:
        traj_coordinates = np.array(random.choices(trajectory, k=2*frequency))
        _draw_policy_arrows(traj_coordinates, ds.predict(traj_coordinates),
                            scale_factor * 0.5, width, 'red')
        legend_handles.append(plt.Line2D([0], [0], color='red', linestyle='-',
            label='Policy Action (On Trajectories)'))

    legend_handles.append(plt.Line2D([0], [0], color='blue', marker='o',
        label='Expert Demonstrations'))

    plt.xlabel('X1', fontsize=PlotConfigs.LABEL_SIZE)
    plt.ylabel('X2', fontsize=PlotConfigs.LABEL_SIZE)
    plt.tick_params(axis='both', which='both', labelsize=PlotConfigs.TICKS_SIZE)

    if title is not None:
        plt.title(title, fontsize=20)
    # ``plt.grid()`` without arguments TOGGLES the grid (calling it twice hides it
    # again), so request it explicitly, once.
    plt.grid(True)

    # Legend built from custom handles, listing only what was actually drawn.
    plt.legend(fontsize=PlotConfigs.LEGEND_SIZE, loc='upper right',
        handles=legend_handles)

    if save_dir != "":
        name = file_name if file_name != "" else 'plot'
        os.makedirs(save_dir, exist_ok=True)
        plt.savefig(os.path.join(save_dir, name), dpi=PlotConfigs.FIGURE_DPI, bbox_inches='tight')
        # Nothing displays this figure; close it so repeated saves don't pile up open figures.
        plt.close(fig)
    else:
        plt.show()


def multi_curve_plot_errorband(xs: List[str] or np.ndarray, y_means: List[np.ndarray],
        y_vars: List[np.ndarray], legends: List[str] = None, xlabel: str = "X",
        std_exaggeration: float = 2.0, ylabel: str = "Y",
        file_name: str = "", save_dir: str = "", use_boxes: bool = True,
        column_space: float = 10, inter_column_space: float = 1.2, log: bool = False):
    """ Plot multiple curves, either as error-bar lines or as grouped violin plots.

    # TODO: Switch to datasamples instead of mean/var composition.
    # TODO: Messy function close to the deadline! Refactor later.

    With ``use_boxes=True``, every (mean, spread) pair is turned into 100 normal samples
    and drawn as a violin. The violins of the different curves are grouped side by side
    under each x label, with dashed separators between groups. Otherwise every curve is
    drawn with ``plt.errorbar``. The figure is always shown, and also saved when
    ``save_dir`` is given.

    Args:
        xs (List[str] or np.ndarray): Values (or labels) for the xaxis, one per column.
        y_means (List[np.ndarray]): One array of means per curve, one entry per x value.
        y_vars (List[np.ndarray]): One array of spreads per curve. NOTE: the values are
            used directly as a standard deviation (error-bar half-length / normal scale),
            not square-rooted.
        legends (List[str], optional): Legends corresponding to y_means. Defaults to None
            (no legend is drawn).
        xlabel (str, optional): xaxis label. Defaults to "X".
        std_exaggeration (float, optional): Multiplier applied to ``y_vars``. Defaults to 2.0.
        ylabel (str, optional): yaxis label. Defaults to "Y".
        file_name(str, optional): Name of the plot file. Defaults to "" (saved as "plot").
        save_dir(str, optional): Provide a save directory for the figure. Leave empty to
            skip saving. Defaults to "".
        use_boxes (bool, optional): Draw violins (True) or error bars (False). Defaults to True.
        column_space (float, optional): Distance between consecutive x groups in violin
            mode. Defaults to 10.
        inter_column_space (float, optional): Offset between the curves inside one group
            in violin mode. Defaults to 1.2.
        log (bool, optional): In violin mode, take the natural log of the means before
            sampling (the spread is not transformed). Ignored for error bars.
            Defaults to False.
    """
    colors, fmts = PlotConfigs.COLORS, PlotConfigs.FMTS
    curves = list(zip(y_means, y_vars))
    n_curves = len(curves)
    n_columns = len(xs)
    # Curve k of a group sits at offset k * inter_column_space, so the group's
    # center is (n_curves - 1) / 2 slots to the right of its first violin.
    group_center = max(n_curves - 1, 0) / 2 * inter_column_space

    plt.figure(figsize=PlotConfigs.FIGURE_SIZE, dpi=PlotConfigs.FIGURE_DPI)
    axes = plt.gca()

    for idx, (y_mean, y_var) in enumerate(curves):
        # Cycle the palettes so more curves than palette entries don't raise IndexError.
        color = colors[idx % len(colors)]
        label = legends[idx] if legends is not None else None

        if not use_boxes:
            plt.errorbar(x=xs, y=y_mean, yerr=std_exaggeration * y_var,
                color=color, label=label, fmt=fmts[idx % len(fmts)],
                capsize=5, elinewidth=2, markeredgewidth=3, linewidth=2)
            continue

        samples = [np.random.normal(np.log(y_m) if log else y_m, y_v * std_exaggeration, size=100)
                   for y_m, y_v in zip(y_mean, y_var)]
        positions = [column_space * pos + idx * inter_column_space
                     for pos in range(1, n_columns + 1)]
        violin = axes.violinplot(samples, positions=positions, widths=2.5, showmeans=True)

        for body in violin['bodies']:
            body.set_facecolor(color)
            body.set_alpha(0.4)
            body.set_linewidth(2)

    axes.set_ylabel(ylabel, fontsize=PlotConfigs.LABEL_SIZE)
    axes.set_xlabel(xlabel, fontsize=PlotConfigs.LABEL_SIZE)

    if use_boxes:
        # Dashed separators halfway between consecutive groups (and at both ends).
        for pos in range(0, n_columns + 1):
            axes.axvline(column_space * (pos + 1/2) + group_center,
                         color='gray', linestyle='dashed', linewidth=1)

        axes.set_xticks([column_space * pos + group_center for pos in range(1, n_columns + 1)],
                        labels=xs)
    plt.tick_params(axis='both', which='both', labelsize=PlotConfigs.TICKS_SIZE)

    if use_boxes:
        plt.grid(axis='y', linestyle='dashed')
        if legends is not None:
            # Violins carry no labels, so build matching colored patches by hand.
            legend_handles = [
                Patch(facecolor=colors[i % len(colors)], edgecolor='black')
                for i in range(len(legends))
            ]
            plt.legend(legend_handles, legends, loc='upper left',
                       fontsize=PlotConfigs.LEGEND_SIZE - 2, ncol=1)
    else:
        plt.grid(axis='both', linestyle='dashed')
        if legends is not None:
            plt.legend(loc='upper right', fontsize=PlotConfigs.LEGEND_SIZE - 2)

    if save_dir != "":
        name = file_name if file_name != "" else 'plot'
        os.makedirs(save_dir, exist_ok=True)
        plt.savefig(os.path.join(save_dir, name), dpi=PlotConfigs.FIGURE_DPI, bbox_inches='tight')

    plt.show()


def _save_metric_boxplot(ys: List, xs: List[str], title: str, xlabel: str, ylabel: str,
                         out_file: str, dpi: int):
    """Draw one box plot (one box per entry of ``ys``), save it to ``out_file`` and show it."""
    plt.figure(figsize=PlotConfigs.FIGURE_SIZE, dpi=dpi)
    axes = plt.gca()

    axes.boxplot(ys, meanline=True, showmeans=True)
    axes.yaxis.grid(True, linestyle='-', which='major', color='lightgrey', alpha=0.5)

    axes.set_ylabel(ylabel, fontsize=10)
    axes.set_xlabel(xlabel, fontsize=10)

    axes.set_xticklabels(xs, fontsize=8)
    axes.set_title(title, fontsize=14)
    plt.savefig(out_file)
    plt.show()


def plot_performance_curves(file_path: str = "../../res/linear_direct_transfer.npy",
    keys: List[str] = None, num_execs: int = 100):
    """ Plotting the num_iters or time for transfer ds learning.

    For every key, two box plots comparing the reference, transfer and partial
    policies are saved in the current directory and shown:
    ``time_performance_<key>.png`` (column 0 of each ``*_time`` entry) and
    ``num_iters_performance_<key>.png`` (column 1).

    Args:
        file_path (str, optional): Path of the data file. Defaults to
            "../../res/linear_direct_transfer.npy".
        keys (List[str], optional): Result keys (DS types) to plot. Defaults to None,
            meaning ["linear"].
        num_execs (int, optional): Total number of executions. Currently unused; kept
            for backward compatibility. Defaults to 100.
    """
    # None instead of a list literal: a mutable default would be shared across calls.
    if keys is None:
        keys = ["linear"]

    # NOTE(review): allow_pickle=True unpickles the file, which can execute arbitrary
    # code. Only load result files you produced yourself.
    results = np.load(file_path, allow_pickle=True)

    policies = ['reference', 'transfer', 'partial']
    xlabel = "Transfer policy"

    for key in keys:
        # One array per policy; column 0 is plotted as time, column 1 as iterations.
        times = [np.array([res[key][f"{policy}_time"] for res in results])
                 for policy in policies]
        title = f'Evaluation of transfer retrain for {key} DS'

        # plot the time performance
        _save_metric_boxplot([t[:, 0] for t in times], policies, title, xlabel,
                             "Optimization time (seconds)", f'time_performance_{key}.png',
                             dpi=150)

        # plot the number of iterations
        _save_metric_boxplot([t[:, 1] for t in times], policies, title, xlabel,
                             "Optimization iterations", f'num_iters_performance_{key}.png',
                             dpi=PlotConfigs.FIGURE_DPI)


def plot_contours(lpf, range: np.ndarray = (-50, 50, -50, 50), step_size: float = 0.001,
             save_dir: str = "", file_name: str = "", color: str = 'Greens_r'):
    """Contour plot of an LPF function over a rectangular region.

    The function is evaluated on a 100 x 100 grid. The values are divided by their
    Frobenius norm, and contour lines are drawn every ``|step_size|`` between the
    minimum and maximum of the normalized values. The raw min/max are printed.

    Args:
        lpf (Function): The function to plot; called once per grid point with a
            length-2 array ``[x, y]`` and expected to return a scalar.
        range (np.ndarray, optional): Ranges on both x and y axis in order,
            ``[x_min, x_max, y_min, y_max]``. Defaults to (-50, 50, -50, 50).
            (The name shadows the builtin ``range`` inside this function.)
        step_size (float, optional): Spacing between contour levels in normalized
            units; its sign is ignored. Defaults to 0.001.
        save_dir(str, optional): Provide a save directory for the figure. Leave empty to
            skip saving and show the figure instead.
        file_name(str, optional): Name of the plot file. Defaults to "" (saved as "plot").
        color (str, optional): Matplotlib colormap name. Defaults to 'Greens_r'.
    """
    resolution = 100
    fig = plt.figure(figsize=PlotConfigs.FIGURE_SIZE, dpi=PlotConfigs.FIGURE_DPI)

    x = np.linspace(range[0], range[1], resolution)
    y = np.linspace(range[2], range[3], resolution)
    X, Y = np.meshgrid(x, y)

    data = np.column_stack([X.ravel(), Y.ravel()])
    Z = np.apply_along_axis(lpf, 1, data).reshape(resolution, resolution)
    print(f'{np.min(Z)}, {np.max(Z)}')

    # Normalize by the Frobenius norm. Out-of-place division also works when lpf
    # returns integers, and an all-zero field is left untouched instead of
    # becoming NaN.
    norm = np.linalg.norm(Z)
    if norm > 0:
        Z = Z / norm
    step = np.abs(step_size)

    # Pass the grid coordinates so the axes show data coordinates, not grid indices.
    plt.contour(X, Y, Z, cmap=color, levels=np.arange(np.min(Z), np.max(Z) + step, step))
    plt.colorbar()

    plt.xlabel('X1', fontsize=PlotConfigs.LABEL_SIZE)
    plt.ylabel('X2', fontsize=PlotConfigs.LABEL_SIZE)

    if save_dir != "":
        name = file_name if file_name != "" else 'plot'
        os.makedirs(save_dir, exist_ok=True)
        plt.savefig(os.path.join(save_dir, name), dpi=PlotConfigs.FIGURE_DPI, bbox_inches='tight')
        # Nothing displays this figure; close it so repeated saves don't pile up open figures.
        plt.close(fig)
    else:
        plt.show()
