from __future__ import print_function
from matplotlib import cm
import matplotlib.pyplot as plt
import numpy as np
import torch
import seaborn as sns
from mpl_toolkits.mplot3d import Axes3D
import jacinle.io as io
from sklearn.neighbors import KernelDensity

# * cfg related


def sample2density_1d(cfg, P_k, results_save_path=None, load_epoch=None, save=False, energy_or_density='density'):
    """Fit a 1-D Gaussian KDE to ``P_k`` and plot it on the current figure.

    The KDE is evaluated at 100 evenly spaced points on
    ``[-cfg.plot_size, cfg.plot_size]``.

    Args:
        cfg: config object; ``cfg.band_width`` is the KDE bandwidth and
            ``cfg.plot_size`` the half-width of the evaluation interval.
        P_k: samples passed to ``KernelDensity.fit`` (presumably shape (n, 1)
            — confirm with callers).
        results_save_path: directory receiving ``epoch{load_epoch}.png`` when
            ``save`` is true.
        load_epoch: epoch tag used in the saved file name.
        save: when true, save and close the current figure via ``save_fig``.
        energy_or_density: ``'energy'`` plots the raw ``score_samples`` output
            (the log-density); any other value plots its exponential.

    Returns:
        ``(x_axis, values)``: the (100, 1) evaluation grid and the plotted values.
    """
    estimator = KernelDensity(kernel='gaussian', bandwidth=cfg.band_width).fit(P_k)
    grid = np.linspace(-cfg.plot_size, cfg.plot_size, 100).reshape(-1, 1)

    log_density = estimator.score_samples(grid)
    values = log_density if energy_or_density == 'energy' else np.exp(log_density)
    plt.plot(grid, values)

    if save:
        save_fig(results_save_path + f'/epoch{load_epoch}.png')
    return grid, values


# * distribution orientated


def GM_density_1d(cfg, density):
    """Plot an equally weighted 1-D Gaussian-mixture density on the current figure.

    The mixture is evaluated at 100 points on ``[-cfg.plot_size, cfg.plot_size]``
    and drawn in colour ``'C1'``.

    Args:
        cfg: config object; ``cfg.plot_size`` bounds the interval and
            ``cfg.NUM_GMM_COMPONENT[1]`` is the number of components summed.
        density: pair ``(means, inverse_covariances)`` of torch tensors indexed by
            component. Each inverse covariance is presumably a 1x1 tensor, since
            ``1 / precision`` is used as the covariance — confirm with the caller.
    """
    grid = torch.linspace(-cfg.plot_size, cfg.plot_size, 100).reshape(-1, 1)
    means = density[0].detach().cpu()
    precisions = density[1].detach().cpu()
    num_components = cfg.NUM_GMM_COMPONENT[1]
    weight = 1 / num_components

    mixture = 0
    for k in range(num_components):
        centered = grid - means[k]
        # Component weight times the Gaussian normaliser 1 / sqrt(2*pi*var).
        scale = weight / torch.sqrt(2 * np.pi * torch.det(1 / precisions[k]))
        # Row-wise quadratic form (x - mu) * precision * (x - mu).
        quad = ((centered @ precisions[k]) * centered).sum(axis=1)
        mixture += scale * torch.exp(-quad / 2)

    plt.plot(grid, mixture, c='C1')


def GM_density_2d(cfg, density, save_path=None):
    """Contour-plot an equally weighted 2-D Gaussian mixture, then save and close.

    The density is evaluated on a 100-per-axis grid over
    ``[-cfg.plot_size, cfg.plot_size)^2`` (see ``grid_NN_2_generator``).

    Args:
        cfg: config object; ``cfg.plot_size`` is the half-width of the window and
            ``cfg.NUM_GMM_COMPONENT[1]`` the number of components summed.
        density: pair ``(means, covariances)`` of torch tensors indexed by
            component; presumably means are (2,) and covariances (2, 2) —
            confirm with the caller.
        save_path: output file; ``None`` only closes the figure (``save_fig``).
    """
    num_grid = 100
    x, y, pos_plot = grid_NN_2_generator(
        num_grid, -cfg.plot_size, cfg.plot_size)
    pos_plot = torch.from_numpy(pos_plot).float()
    mean_q_list, cov_q = density[0].detach().cpu(), density[1].detach().cpu()
    q_density = 0
    weight = 1 / cfg.NUM_GMM_COMPONENT[1]
    for idx in range(cfg.NUM_GMM_COMPONENT[1]):
        diff = pos_plot - mean_q_list[idx]
        # Component weight times the normaliser 1 / sqrt(det(2*pi*Sigma)).
        scale = weight / torch.sqrt(torch.det(2 * np.pi * cov_q[idx]))
        # Row-wise Mahalanobis term diff^T Sigma^{-1} diff.
        mahalanobis = ((diff @ torch.inverse(cov_q[idx])) * diff).sum(axis=1)
        q_density += scale * torch.exp(-mahalanobis / 2)
    # Reshape to the grid actually produced rather than a hard-coded (100, 100):
    # a float-step grid is not guaranteed to have exactly num_grid nodes per axis.
    contour_alone(x, y, q_density.reshape(x.shape))
    save_fig(save_path)
# * handles


def sns_scatter_handle(data_np_n_2, left_place, right_place, save_path, figsize=(10, 10), opacity=1, scatter_size=10, new_fig=True):
    """Seaborn-scatter 2-D points in a square window, then save and close.

    Args:
        data_np_n_2: array whose columns 0 and 1 are x and y.
        left_place, right_place: limits applied to both axes.
        save_path: output file; ``None`` only closes the figure (``save_fig``).
        figsize, opacity, scatter_size, new_fig: forwarded to ``Axis_Params``.
    """
    sns_scatter_alone(
        data_np_n_2,
        Axis_Params(left_place, right_place, figsize,
                    opacity=opacity,
                    scatter_size=scatter_size,
                    new_fig=new_fig),
    )
    save_fig(save_path)


def plt_scatter_3dhandle(data_np_n_2, left_place, right_place, save_path, figsize=(10, 10), opacity=1, scatter_size=10):
    """3-D scatter of points inside a cube, with a fixed camera, then save and close.

    The camera uses ``x_rotate=20`` / ``z_rotate=45``, which ``plt_scatter_3d_alone``
    passes to ``view_init`` as elevation / azimuth.

    Args:
        data_np_n_2: array whose first three columns are x, y, z.
        left_place, right_place: bounds applied to all three axes.
        save_path: output file; ``None`` only closes the figure (``save_fig``).
        figsize, opacity, scatter_size: forwarded to ``Axis_Params_3d``.
    """
    camera = {'x_rotate': 20, 'z_rotate': 45}
    params = Axis_Params_3d(left_place=left_place, right_place=right_place,
                            figsize=figsize, opacity=opacity,
                            scatter_size=scatter_size, **camera)
    plt_scatter_3d_alone(data_np_n_2, params)
    save_fig(save_path)


def plot_3dmarginal(train_data, results_save_path, left_place, right_place):
    """Scatter slices 0 and 1 of ``train_data``'s last axis together in 3-D.

    ``train_data[:, :, 0]`` and ``train_data[:, :, 1]`` are stacked row-wise into
    one point set and written to ``<results_save_path>/2blocks.png``.

    Args:
        train_data: torch tensor indexed ``[n, k, block]``; presumably
            ``(n, 3, >=2)`` so each slice is an (n, 3) point cloud — confirm.
        results_save_path: output directory.
        left_place, right_place: bounds applied to all three axes.
    """
    # .cpu() first: Tensor.numpy() only works on host memory, so the former
    # .detach().numpy() failed for CUDA tensors. This matches the GM_density_* helpers.
    plt_total_data = train_data.detach().cpu().numpy()
    plt_data = np.concatenate(
        [plt_total_data[:, :, 0], plt_total_data[:, :, 1]], axis=0)
    plt_scatter_3dhandle(plt_data, left_place, right_place,
                         results_save_path + '/2blocks.png')


def plot_rgb_cloud_alone(cloud, save_path, num_point=1024):
    """Scatter a random subset of colour points in the unit cube, coloured by value.

    ``num_point`` row indices are drawn with ``np.random.choice``, which samples
    with replacement by default. The three columns are then reversed (order 2, 1, 0);
    presumably this converts BGR to RGB — confirm with the data source. The
    reordered values are used both as coordinates and as point colours, with all
    axes bounded to [0, 1]. The figure is saved and closed via ``save_fig``.
    """
    picked_rows = np.random.choice(cloud.shape[0], num_point)
    reversed_channels = cloud[picked_rows][:, np.array([2, 1, 0])]
    params = Axis_Params_3d(left_place=0, right_place=1,
                            colors=reversed_channels)
    plt_scatter_3d_alone(reversed_channels, params)
    save_fig(save_path)

# * vector field


def vector_field(pos, y_psf, num_grid, path):
    """Quiver-plot the displacement from each grid node to its mapped point and save.

    Draws on the current figure and saves it with ``plt.savefig``; the figure is
    not closed.

    Args:
        pos: grid of points indexed ``[i, j, coord]`` (e.g. ``grid_N_N_2_generator``).
        y_psf: mapped points, one row per flattened grid node (presumably
            ``(num_grid**2, 2)`` — confirm with callers).
        num_grid: side length used to fold the displacements back onto the grid.
        path: output image file.
    """
    flat_pos = pos.reshape(-1, 2)
    shift = y_psf - flat_pos
    grid_shape = (num_grid, num_grid)
    plt.quiver(pos[:, :, 0], pos[:, :, 1],
               shift[:, 0].reshape(grid_shape),
               shift[:, 1].reshape(grid_shape))
    plt.savefig(path)

# * set the grid and index


def grid_NN_2_generator(num_grid, left_place, right_place):
    """Return the square grid both as 2-D coordinate arrays and as a flat point list.

    Returns:
        ``(x, y, pos_plot)``: ``x`` and ``y`` come from ``xyIndex_generator``;
        ``pos_plot`` has one ``(x, y)`` row per grid node, in the row-major
        order of ``x``/``y``.
    """
    x, y = xyIndex_generator(num_grid, left_place, right_place)
    pos_plot = np.vstack((x.ravel(), y.ravel())).T
    return x, y, pos_plot


def grid_N_N_2_generator(num_grid, left_place, right_place):
    """Return the square grid as one array indexed ``[i, j, coord]``.

    ``[..., 0]`` holds the x coordinate and ``[..., 1]`` the y coordinate from
    ``xyIndex_generator``.
    """
    xs, ys = xyIndex_generator(num_grid, left_place, right_place)
    return np.stack((xs, ys), axis=-1)


def xyIndex_generator(num_grid, left_place, right_place):
    """Return ``num_grid x num_grid`` coordinate arrays over ``[left_place, right_place)``.

    Nodes start at ``left_place`` and are spaced ``(right_place - left_place) / num_grid``
    apart, so the right end point is excluded. These are the same values the former
    ``np.mgrid[left:right:step]`` form produced. ``x`` varies along axis 0 and ``y``
    along axis 1 (``'ij'`` indexing).

    Returns:
        ``(x, y)``, each of shape ``(num_grid, num_grid)``.
    """
    grid_size = (right_place - left_place) / num_grid
    # Build each axis from an integer count rather than a float step. np.mgrid with
    # a float step sizes its output as ceil(span / step), and rounding can push
    # that to num_grid + 1 (e.g. num_grid=49 on [0, 1] gives 50 nodes). That breaks
    # callers that reshape results to (num_grid, num_grid).
    axis = np.arange(num_grid) * grid_size + left_place
    x, y = np.meshgrid(axis, axis, indexing='ij')
    return x, y

# * scatter


def plt_scatter_3d_alone(sample_n_3, ax_params):
    """Scatter 3-D points on a new figure with a fixed, tick-free cube view.

    Args:
        sample_n_3: array whose columns 0, 1, 2 are the x, y, z coordinates.
        ax_params: ``Axis_Params_3d``-like object. Reads ``figsize``, ``opacity``,
            ``scatter_size``, ``colors``, ``x_rotate``/``z_rotate`` (passed as
            ``elev``/``azim``), ``left_place``/``right_place`` (bounds on all
            three axes), ``xlabel``/``ylabel``/``zlabel``, and the fields read
            by ``set_matplotlib_axis``.

    Returns:
        The 3-D axes.
    """
    fig = plt.figure(figsize=ax_params.figsize)
    ax3d = fig.add_subplot(111, projection='3d')
    xs, ys, zs = sample_n_3[:, 0], sample_n_3[:, 1], sample_n_3[:, 2]
    ax3d.scatter(xs, ys, zs, alpha=ax_params.opacity,
                 s=ax_params.scatter_size, c=ax_params.colors)
    set_matplotlib_axis(ax3d, ax_params)
    ax3d.view_init(elev=ax_params.x_rotate, azim=ax_params.z_rotate)
    # Original note: autoscaling must be disabled for the Z-axis bounds below to apply.
    ax3d.autoscale(enable=False, axis='both')
    ax3d.grid(False)
    for clear_ticks in (ax3d.set_xticks, ax3d.set_yticks, ax3d.set_zticks):
        clear_ticks([])

    lo, hi = ax_params.left_place, ax_params.right_place
    for set_bound in (ax3d.set_xbound, ax3d.set_ybound, ax3d.set_zbound):
        set_bound(lo, hi)
    for set_label, text in ((ax3d.set_xlabel, ax_params.xlabel),
                            (ax3d.set_ylabel, ax_params.ylabel),
                            (ax3d.set_zlabel, ax_params.zlabel)):
        set_label(text)
    return ax3d


def sns_scatter_alone(sample_nn_2, ax_params):
    """Seaborn scatter of 2-D points in a square window with the axis hidden.

    Args:
        sample_nn_2: array whose columns 0 and 1 are x and y.
        ax_params: ``Axis_Params``-like. Reads ``new_fig`` (open a new figure of
            ``figsize`` first), ``opacity``, ``scatter_size``, ``left_place`` and
            ``right_place``.

    Returns:
        The axes returned by ``sns.scatterplot``.
    """
    if ax_params.new_fig:
        plt.figure(figsize=ax_params.figsize)
    ax = sns.scatterplot(
        x=sample_nn_2[:, 0], y=sample_nn_2[:, 1],
        alpha=ax_params.opacity,
        s=ax_params.scatter_size,
        legend=False)
    ax.axis("off")
    # Pass visibility positionally. The ``b=`` keyword was renamed ``visible`` in
    # Matplotlib 3.5 and later removed; the positional form works on every version.
    ax.grid(True, which='major', color='w', linewidth=1.0)
    ax.grid(True, which='minor', color='w', linewidth=0.5)
    ax.set_xlim(ax_params.left_place, ax_params.right_place)
    ax.set_ylim(ax_params.left_place, ax_params.right_place)
    return ax


def plt_scatter_alone(sample_nn_2, ax_params):
    """Plain-matplotlib scatter of 2-D points as faint grey dots, axis hidden.

    Only ``new_fig``, ``figsize``, ``left_place`` and ``right_place`` are read from
    ``ax_params``. Colour (grey) and opacity (alpha 0.1) are fixed here.
    """
    if ax_params.new_fig == True:
        plt.figure(figsize=ax_params.figsize)
    xs, ys = sample_nn_2[:, 0], sample_nn_2[:, 1]
    plt.scatter(xs, ys, edgecolors='grey', color='grey', alpha=0.1)
    plt.axis("off")
    lo, hi = ax_params.left_place, ax_params.right_place
    plt.xlim(lo, hi)
    plt.ylim(lo, hi)

# * contour


class DIM2_PLOT:
    """Stateful helper for 2-D KDE contour plots and scatter plots.

    Holds one ``Axis_Params`` (square window ``[left_place, right_place]``, KDE
    bandwidth, grid resolution, tick label size). It caches the results of the
    most recent ``contour_from_sample`` call.
    """

    def __init__(self, num_grid=100, left_place=-10, right_place=10, bandwidth=0.1, label_size=15):
        self.ax_param = Axis_Params(
            left_place, right_place, label_font_size=label_size, bandwidth_kde=bandwidth, num_grid=num_grid)
        # Results of the last contour_from_sample call; None until it runs.
        # xy_list is initialised here too, so every cached attribute exists from construction.
        self.x = None
        self.y = None
        self.xy_list = None
        self.kde_density = None

    def contour_from_sample(self, sample_n_2, save_path=None, new_fig=True):
        """Fit a Gaussian KDE to ``sample_n_2``, draw its contours, and optionally save.

        Returns:
            ``[x, y, xy_list, kde_density]``: grid coordinate arrays, the flat
            list of grid nodes, and the KDE density evaluated on the grid.
        """
        self.ax_param.new_fig = new_fig
        self.x, self.y, self.xy_list, self.kde_density = sample2density_2d(
            sample_n_2, self.ax_param)
        contour_alone(self.x, self.y, self.kde_density,
                      self.ax_param, save_path)
        return [self.x, self.y, self.xy_list, self.kde_density]

    def scatter(self, sample_n_2, save_path=None, new_fig=True):
        """Scatter ``sample_n_2`` via ``plt_scatter_alone``; save and close if ``save_path`` is given."""
        self.ax_param.new_fig = new_fig
        plt_scatter_alone(sample_n_2, self.ax_param)
        if save_path is not None:
            save_fig(save_path)


def sample2density_2d(sample_n_2, ax_params):
    """Evaluate a Gaussian KDE of 2-D samples on a square grid.

    Args:
        sample_n_2: samples passed to ``KernelDensity.fit`` (one row per point).
        ax_params: reads ``num_grid``, ``left_place`` and ``right_place`` (grid)
            and ``bandwidth`` (KDE).

    Returns:
        ``(x, y, pos_plot, density)``: grid coordinate arrays, the flat
        ``(n_nodes, 2)`` array of grid nodes, and the KDE density (not the
        log-density) reshaped to the grid's shape.
    """
    x, y, pos_plot = grid_NN_2_generator(
        ax_params.num_grid, ax_params.left_place, ax_params.right_place)

    kde = KernelDensity(
        kernel='gaussian', bandwidth=ax_params.bandwidth).fit(sample_n_2)
    # score_samples returns the log-density; exponentiate to get the density.
    brct_KDE = np.exp(kde.score_samples(pos_plot))
    # Match the grid actually built instead of assuming num_grid x num_grid,
    # which a float-step grid does not always guarantee.
    brct_KDE_plot = brct_KDE.reshape(x.shape)
    return x, y, pos_plot, brct_KDE_plot


def surface_alone(x, y, z, ax_params=None, save_path=None):
    """Draw ``z`` over the ``(x, y)`` grid as a coolwarm 3-D surface on a new figure.

    Args:
        x, y, z: 2-D arrays accepted by ``Axes3D.plot_surface``.
        ax_params: optional. Supplies ``figsize`` and the x/y bounds
            ``[left_place, right_place]``. When ``None``, Matplotlib's default
            figure size and bounds are used.
        save_path: if given, the figure is saved and closed via ``save_fig``.
    """
    # ax_params.figsize used to be read before the None check, so calling with
    # the default ax_params=None raised AttributeError.
    figsize = ax_params.figsize if ax_params is not None else None
    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111, projection='3d')
    ax.plot_surface(
        x, y, z, cmap=cm.coolwarm, linewidth=0, antialiased=False)
    if ax_params is not None:
        ax.set_xbound(ax_params.left_place, ax_params.right_place)
        ax.set_ybound(ax_params.left_place, ax_params.right_place)

    if save_path is not None:
        save_fig(save_path)


def contour_alone(x, y, z, ax_params=None, save_path=None):
    """Draw 10 contour levels of ``z`` over ``(x, y)`` with the axis hidden.

    Args:
        x, y, z: arrays accepted by ``plt.contour``.
        ax_params: optional. When given, ``new_fig`` decides whether to open a
            new figure of ``figsize``, and ``left_place``/``right_place``/
            ``label_font_size`` set the limits and tick label size. When ``None``,
            draws on the current figure.
        save_path: if given, the figure is saved and closed via ``save_fig``.

    Returns:
        The axes the contours were drawn on.
    """
    if ax_params is not None and ax_params.new_fig:
        plt.figure(figsize=ax_params.figsize)
    plt.contour(x, y, z, levels=10)
    plt.axis("off")
    if ax_params is not None:
        plt.xlim(ax_params.left_place, ax_params.right_place)
        plt.ylim(ax_params.left_place, ax_params.right_place)
        plt.tick_params(axis='both', which='major',
                        labelsize=ax_params.label_font_size)
    # Grab the axes before saving. save_fig closes the figure, and calling
    # plt.gca() afterwards would open and return a new, empty figure.
    ax = plt.gca()
    if save_path is not None:
        save_fig(save_path)
    return ax


def sns_kdeplot_alone(sample_nn_2, ax_params, bandwidth):
    """Filled bivariate seaborn KDE of 2-D samples on a new figure.

    Args:
        sample_nn_2: array whose columns 0 and 1 are the x and y samples.
        ax_params: supplies ``figsize``.
        bandwidth: forwarded to ``sns.kdeplot`` as ``bw_method``.

    Returns:
        The axes returned by ``sns.kdeplot``.
    """
    plt.figure(figsize=ax_params.figsize)
    # ``shade``/``bw`` have been deprecated aliases since seaborn 0.11, which this
    # module already requires (``thresh`` in jointplot). seaborn maps them to
    # ``fill``/``bw_method``, so this is the same call, minus the FutureWarnings
    # and the breakage once the aliases are removed.
    ax = sns.kdeplot(x=sample_nn_2[:, 0],
                     y=sample_nn_2[:, 1], fill=True, bw_method=bandwidth)
    return ax


def sns_jointplot_alone(sample_x, sample_y, save_path, ax_params=None):
    """Save a KDE joint plot (bivariate density plus marginals, ``thresh=0.05``).

    Args:
        sample_x, sample_y: samples for the horizontal and vertical axes.
        save_path: output image file.
        ax_params: optional; its ``left_place``/``right_place`` set both joint-axis limits.
    """
    # sns.jointplot creates and activates its own figure. The former extra
    # plt.figure() left an empty figure open that plt.close() below never closed.
    jg = sns.jointplot(x=sample_x, y=sample_y, kind='kde', thresh=0.05)
    if ax_params is not None:
        jg.ax_joint.set_xlim(ax_params.left_place, ax_params.right_place)
        jg.ax_joint.set_ylim(ax_params.left_place, ax_params.right_place)
    plt.savefig(save_path)
    plt.close()


def error_bar(result_n_expr_n_repeat, label, x_axis=None, line_width=6):
    """Plot mean +/- one standard deviation across repeats for each experiment.

    Args:
        result_n_expr_n_repeat: array indexed ``[experiment, repeat]``; the
            statistics are taken over axis 1.
        label: legend label for the series.
        x_axis: x position of each experiment; defaults to ``[2, 16, 64, 128, 256]``.
        line_width: error-bar line width (``elinewidth``).
    """
    # None sentinel instead of an array default: a default array is a single
    # object shared (and mutable) across every call.
    if x_axis is None:
        x_axis = np.array([2, 16, 64, 128, 256])
    mean = result_n_expr_n_repeat.mean(axis=1)
    std = result_n_expr_n_repeat.std(axis=1)
    plt.errorbar(x_axis, mean, std,
                 label=label, elinewidth=line_width)
#---------------#


def set_sns_axis(ax, ax_params):
    """Apply square limits and a title to ``ax`` and hide its first collection.

    ``ax.collections[0]`` is made fully transparent (alpha 0). Presumably this is
    the artist seaborn drew first; confirm with callers.

    Returns:
        ``ax``, modified in place.
    """
    lo, hi = ax_params.left_place, ax_params.right_place
    ax.set_xlim(lo, hi)
    ax.set_ylim(lo, hi)
    first_collection = ax.collections[0]
    first_collection.set_alpha(0)
    ax.title.set_text(ax_params.title)
    return ax


def set_matplotlib_axis(ax, ax_params):
    """Apply square limits, major-tick label size and title from ``ax_params``.

    Reads ``left_place``/``right_place`` (x and y limits), ``axis_font_size``
    (major tick labels on both axes) and ``title``/``title_font_size``.

    Returns:
        ``ax``, modified in place.
    """
    bounds = (ax_params.left_place, ax_params.right_place)
    ax.set_xlim(*bounds)
    ax.set_ylim(*bounds)
    ax.tick_params(axis='both', which='major', labelsize=ax_params.axis_font_size)
    ax.set_title(ax_params.title, fontsize=ax_params.title_font_size)
    return ax

#---------------#

#! save


def save_fig(path, tight_flag=True):
    """Save the current figure to ``path`` if it is not ``None``, then always close it.

    Args:
        path: output file, or ``None`` to only close the figure.
        tight_flag: when exactly ``True``, save with ``bbox_inches='tight'``.
    """
    if path is not None:
        savefig_kwargs = {'bbox_inches': 'tight'} if tight_flag is True else {}
        plt.savefig(path, **savefig_kwargs)
    plt.close()

#! class of axis


class Axis_Params:
    """Plotting settings shared by the helpers in this module.

    Attributes:
        left_place, right_place: lower/upper bound used for axis limits and grids.
        title: axes title.
        label_font_size: tick label size used by ``contour_alone``.
        axis_font_size: tick label size used by ``set_matplotlib_axis``.
        title_font_size: title font size.
        figsize: figure size passed to ``plt.figure``.
        opacity: scatter alpha.
        scatter_size: scatter marker size.
        new_fig: whether a helper opens a new figure before drawing.
        bandwidth: KDE bandwidth (constructor argument ``bandwidth_kde``).
        num_grid: grid nodes per axis for density evaluation.
    """

    def __init__(
        self,
        left_place,
        right_place,
        figsize=(10, 10),
        title='',
        label_font_size=15,
        axis_font_size=15,
        title_font_size=22,
        opacity=1,
        scatter_size=20,
        new_fig=True,
        bandwidth_kde=0.9,
        num_grid=100,
    ):
        # Window bounds.
        self.left_place, self.right_place = left_place, right_place
        # Text.
        self.title = title
        self.label_font_size = label_font_size
        self.axis_font_size = axis_font_size
        self.title_font_size = title_font_size
        # Figure and markers.
        self.figsize = figsize
        self.opacity = opacity
        self.scatter_size = scatter_size
        self.new_fig = new_fig
        # Density estimation (note the public name differs from the argument).
        self.bandwidth = bandwidth_kde
        self.num_grid = num_grid


class Axis_Params_3d(Axis_Params):
    """``Axis_Params`` plus settings for 3-D scatter plots.

    Extra attributes:
        x_rotate, z_rotate: passed as ``elev`` / ``azim`` to ``view_init`` by
            ``plt_scatter_3d_alone``.
        colors: per-point colours (``c=`` of the 3-D scatter).
        xlabel, ylabel, zlabel: axis labels.

    Pass base-class arguments by keyword. Positional arguments bind to the 3-D
    parameters first, and only the surplus reaches ``Axis_Params``.
    """

    def __init__(self, x_rotate=None, z_rotate=None, colors=None, xlabel=None, ylabel=None, zlabel=None, *kargs, **kwargs):
        super().__init__(*kargs, **kwargs)
        self.x_rotate, self.z_rotate = x_rotate, z_rotate
        self.colors = colors
        self.xlabel, self.ylabel, self.zlabel = xlabel, ylabel, zlabel

#! application: lines


def draw_marginal_lines(path, plot_limit=3):
    """Draw black line segments loaded from a text file on a new 10x10 figure.

    Rows ``2i`` and ``2i + 1`` of the loaded array are the end points of segment
    ``i``, with columns 0 and 1 as x and y. A trailing unpaired row is ignored.
    The axis is hidden and both limits are ``[-plot_limit, plot_limit]``. The
    figure is neither saved nor closed.

    Args:
        path: text file read with ``jacinle.io.load_txt``.
        plot_limit: half-width of the square view (default 3, the former
            hard-coded value).
    """
    plt.figure(figsize=(10, 10))
    X = io.load_txt(path)

    for i in range(X.shape[0] // 2):
        segment = X[2 * i:2 * (i + 1)]
        plt.plot(segment[:, 0], segment[:, 1], 'k')
    # Limits and axis visibility are loop-invariant, so set them once. This also
    # applies them when the file contains no complete segment.
    plt.xlim(-plot_limit, plot_limit)
    plt.ylim(-plot_limit, plot_limit)
    plt.axis('off')
