from vgn.utils.transform import Rotation, Transform
from vgn.perception import *
import matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
from math import ceil, floor
from typing import Union, Tuple


def get_cmap(n, name='hsv'):
    '''Return a colormap resampled to ``n`` entries, so indices 0..n-1 map to RGBA colours.

    ``name`` must be a registered matplotlib colormap name.

    ``pyplot.get_cmap`` is used rather than ``plt.cm.get_cmap``. The latter is
    ``matplotlib.cm.get_cmap``, which was deprecated in Matplotlib 3.7 and
    removed in 3.9; ``pyplot.get_cmap(name, lut)`` is available in both old and
    new releases and returns the same resampled colormap.

    Note: for a cyclic map such as the default ``'hsv'``, index 0 and index
    n-1 land at the two ends of the cycle and may look nearly the same.

    Example:
    >>> cmap = get_cmap(count)
    >>> for i in range(count):
    ...     color = cmap(i)
    '''
    return plt.get_cmap(name, n)


def draw_pyramid(ax, extrinsic, color: Union[str, Tuple[float, float, float]] = 'r',
                 alpha=0.35, height=0.3, intrinsic=None, fov=None, label=None):
    """Draw a camera frustum as a translucent pyramid on a 3D axes.

    The pyramid is built in the camera frame with its apex at the origin and its
    rectangular base on the plane z = ``height``. Its homogeneous vertices are
    then mapped through ``inv(extrinsic)`` before plotting, so ``extrinsic`` is
    presumably a 4x4 world-to-camera transform (confirm against callers).

    Args:
        ax: 3D matplotlib axes to draw on.
        extrinsic: 4x4 homogeneous transform; its inverse places the pyramid.
        color: colour used for the faces, edges and corner markers.
        alpha: face transparency.
        height: distance from the apex to the base plane.
        intrinsic: sequence read at indices 0, 1, 4 and 5. Going by the example
            it looks like ``[fx, fy, cx, cy, image_width, image_height]``. The
            base half-extents per unit depth are ``image_width / fx / 2`` and
            ``image_height / fy / 2``.
        fov: used only when ``intrinsic`` is None. ``fov[0]`` and ``fov[1]`` are
            used directly as the half-extents per unit depth.
        label: optional text drawn at the apex.

    If neither ``intrinsic`` nor ``fov`` is given, both half-extents are 0.5.
    Base corners 1 and 4 (the two at negative camera-frame y) are joined by a
    thick black edge, and corner 4 gets a black marker, as an orientation cue.

    Example:
    >>> fig = plt.figure()
    >>> ax = fig.add_subplot(projection='3d')

    >>> intrinsic = [540.0, 540.0, 320.0, 240.0, 640, 480]
    >>> extrinsic = np.array([[-0.7818,  0.6235,  0.    ,  0.0238],
                              [ 0.3117,  0.3909, -0.866 , -0.1054],
                              [-0.54  , -0.6771, -0.5   ,  0.7826],
                              [ 0.    ,  0.    ,  0.    ,  1.    ]])
    >>> draw_pyramid(ax, extrinsic, intrinsic=intrinsic, height=0.1, label='camera')
    """
    if intrinsic is not None:
        # Image-plane half-extent per unit depth: (image size / focal length) / 2.
        half_w = intrinsic[4] / intrinsic[0] / 2.0
        half_h = intrinsic[5] / intrinsic[1] / 2.0
    elif fov is not None:
        half_w, half_h = fov[0], fov[1]
    else:
        half_w = half_h = 0.5

    dx = height * half_w
    dy = height * half_h
    # Camera-frame homogeneous vertices: the apex, then the four base corners.
    camera_vertices = np.array([[0, 0, 0, 1],
                                [dx, -dy, height, 1],
                                [dx, dy, height, 1],
                                [-dx, dy, height, 1],
                                [-dx, -dy, height, 1]])

    # Row-vector form of applying inv(extrinsic) to every vertex.
    world_vertices = camera_vertices @ np.linalg.inv(extrinsic).T
    pts = world_vertices[:, :3]  # drop the homogeneous coordinate

    ax.scatter(pts[:4, 0], pts[:4, 1], pts[:4, 2], color=color, s=10)

    # Four triangular sides, each spanning the apex and two adjacent corners,
    # followed by the rectangular base.
    apex, *base = pts
    faces = [[apex, base[k], base[(k + 1) % 4]] for k in range(4)]
    faces.append(base)
    ax.add_collection3d(
        Poly3DCollection(faces, facecolors=color, linewidths=0.3, edgecolors=color, alpha=alpha))

    # Orientation cue: black marker on corner 4 and a thick black edge 1-4.
    ax.scatter(pts[4, 0], pts[4, 1], pts[4, 2], color='black', s=10)
    edge_x, edge_y, edge_z = ([pts[1, d], pts[4, d]] for d in range(3))
    ax.plot3D(edge_x, edge_y, edge_z, c='black', linewidth=2)

    if label is not None:
        ax.text(pts[0, 0], pts[0, 1], pts[0, 2], label)


def plot_3D_points(ax, pts, color_axis=None, cmap='jet', with_colorbar=True, colorbar_shrink=0.5,
                   size=10, with_origin=True,
                   axis_lim=None, down_sample=None):
    """Scatter-plot an array of 3D points on ``ax``, optionally colour-mapped.

    Args:
        ax: 3D matplotlib axes to draw on.
        pts: array-like whose last dimension holds (x, y, z); all leading
            dimensions are flattened into one list of points.
        color_axis: how to colour the points.
            * ``None``: a single default colour, with no colour mapping.
            * ``int`` k: each point gets its index along axis k of
              ``pts.shape[:-1]``, e.g. to colour a grid of points row by row.
            * array of shape ``pts.shape[:-1]``: one scalar value per point.
            * array of shape ``pts.shape[:-1] + (c,)``: c values per point. With
              c == 1 they are treated as scalars; otherwise the rows are passed
              to ``scatter`` as colours.
        cmap: colormap used for scalar per-point values.
        with_colorbar: add a colorbar. This only happens when scalar values are
            being colour-mapped.
        colorbar_shrink: ``shrink`` factor passed to the colorbar.
        size: marker size.
        with_origin: also mark the origin with a large black 'x'.
        axis_lim: ``(lo, hi)`` limits applied to all three axes. The pair is
            repeated three times to form the range for ``set_axis_range``.
        down_sample: if set and smaller than the number of points, keep this
            many distinct, randomly chosen points.

    Returns:
        The collection returned by ``ax.scatter`` for the points.

    Raises:
        ValueError: if ``color_axis`` is not None, not an int, and its shape
            does not match ``pts`` as described above.
    """
    pts = np.asarray(pts)
    grid_shape = pts.shape[:-1]

    if color_axis is None:
        colors = None
    elif isinstance(color_axis, (int, np.integer)):
        # np.indices(...)[k] holds, for every point, its index along axis k.
        colors = np.indices(grid_shape)[color_axis].reshape(-1)
    else:
        color_arr = np.asarray(color_axis)
        if color_arr.shape == grid_shape:
            colors = color_arr.reshape(-1)
        elif color_arr.shape[:-1] == grid_shape:
            # Keep one row per point so colours stay aligned with the points.
            colors = color_arr.reshape(-1, color_arr.shape[-1])
            if colors.shape[1] == 1:
                colors = colors[:, 0]
        else:
            raise ValueError('color_axis not valid')

    pts = pts.reshape(-1, 3)

    if down_sample and down_sample < pts.shape[0]:
        # Sample without replacement so the retained points are all distinct.
        keep = np.random.choice(pts.shape[0], size=down_sample, replace=False)
        pts = pts[keep]
        if colors is not None:
            colors = colors[keep]

    if with_origin:
        ax.scatter(0, 0, 0, c='black', marker='x', s=100)
    if colors is not None:
        scatter = ax.scatter(pts[:, 0], pts[:, 1], pts[:, 2], s=size, c=colors, cmap=cmap)
    else:
        scatter = ax.scatter(pts[:, 0], pts[:, 1], pts[:, 2], s=size)

    if axis_lim is not None and len(axis_lim) > 0:
        # list(...) so a numpy pair is repeated, not multiplied element-wise.
        set_axis_range(ax, list(axis_lim) * 3)

    # A colorbar only makes sense when scalar values were colour-mapped.
    if with_colorbar and colors is not None and colors.ndim == 1:
        plt.colorbar(scatter, ax=ax, shrink=colorbar_shrink)
    return scatter


def draw_floor_grid(ax, grid_range=(0, 0.3, 0, 0.3), grid_size=0.1, color='k', alpha=0.2):
    """Draw a wireframe grid on the z = 0 plane.

    Args:
        ax: 3D matplotlib axes to draw on.
        grid_range: ``(x_min, x_max, y_min, y_max)``. Both maxima are included
            when they lie on the grid.
        grid_size: spacing between grid lines in both x and y.
        color: wireframe colour.
        alpha: wireframe transparency.

    Returns:
        The artist returned by ``ax.plot_wireframe``.
    """
    # Nudge each stop past its maximum so np.arange keeps the end point despite
    # float round-off. The nudge is capped relative to grid_size so very fine
    # grids do not gain spurious extra lines; it equals the old fixed 1e-4
    # whenever grid_size >= 0.1.
    eps = min(1e-4, 1e-3 * abs(grid_size))
    x = np.arange(grid_range[0], grid_range[1] + eps, grid_size)
    y = np.arange(grid_range[2], grid_range[3] + eps, grid_size)
    X, Y = np.meshgrid(x, y)
    Z = np.zeros(X.shape)
    return ax.plot_wireframe(X, Y, Z, color=color, alpha=alpha)


def draw_sphere(ax, center, radius, color='b', alpha=0.2):
    """Draw a wireframe sphere of ``radius`` centred at ``center`` on ``ax``.

    The surface is sampled on a 20 x 10 grid of azimuth in [0, 2*pi] and polar
    angle in [0, pi]. ``center`` is read at indices 0, 1 and 2.
    """
    azimuth, polar = np.mgrid[0:2 * np.pi:20j, 0:np.pi:10j]
    sin_polar = np.sin(polar)
    xs = radius * np.cos(azimuth) * sin_polar + center[0]
    ys = radius * np.sin(azimuth) * sin_polar + center[1]
    zs = radius * np.cos(polar) + center[2]
    ax.plot_wireframe(xs, ys, zs, color=color, alpha=alpha)


def draw_box(ax, origin, size, facecolors='b', edgecolors='k', alpha=0.2):
    """Draw an axis-aligned box as a single voxel on ``ax``.

    The box spans ``origin[d]`` to ``origin[d] + size[d]`` along each axis d.
    The 2x2x2 arrays of corner coordinates are passed to ``ax.voxels``,
    together with a one-cell ``filled`` mask.
    """
    unit_corners = np.indices((2, 2, 2)).astype(np.float64)
    # Scale each unit-cube coordinate grid by the box size, then shift it to origin.
    x, y, z = (unit_corners[d] * size[d] + origin[d] for d in range(3))

    one_cell = np.ones((1, 1, 1))
    ax.voxels(x, y, z, filled=one_cell, facecolors=facecolors, edgecolors=edgecolors, alpha=alpha)


def mark_axis_label(ax, axis='xyz', fontsize=10):
    """Put generic 'X/Y/Z Label' titles on the axes named in ``axis``.

    Labels are applied in x, y, z order regardless of the order in ``axis``.
    Setters are looked up only for the requested axes, so 2D axes work with
    ``axis='xy'``.
    """
    label_specs = (
        ('x', 'set_xlabel', 'X Label'),
        ('y', 'set_ylabel', 'Y Label'),
        ('z', 'set_zlabel', 'Z Label'),
    )
    for letter, setter_name, text in label_specs:
        if letter not in axis:
            continue
        getattr(ax, setter_name)(text, fontsize=fontsize)


def assign_axis_group(fig, img_num, plot_rows, img_series=1, subplot_kw=None):
    """Create ``img_series`` groups of ``img_num`` subplots each on ``fig``.

    Each series is laid out as a block of rows. The number of columns is
    ``ceil(img_num / plot_rows)``, and each series occupies
    ``ceil(img_num / columns)`` rows. Series are stacked vertically in one
    overall subplot grid.

    Args:
        fig: figure providing ``add_subplot(nrows, ncols, index, **kw)``.
        img_num: number of subplots per series.
        plot_rows: desired number of rows per series; used to pick the column count.
        img_series: number of series (row blocks) to create.
        subplot_kw: extra keyword arguments for every ``add_subplot`` call,
            e.g. ``{'projection': '3d'}``. ``None`` means no extra arguments.

    Returns:
        numpy array of shape ``(img_series, img_num)`` holding the created axes.
        It is empty when ``img_num <= 0``.
    """
    # None used to be splatted with ** and raise TypeError.
    subplot_kw = {} if subplot_kw is None else subplot_kw
    if img_num <= 0:
        # No subplots requested; avoid the 0/0 in the layout computation below.
        return np.empty((img_series, 0), dtype=object)

    img_in_a_row = ceil(img_num / plot_rows)
    actual_img_rows = ceil(img_num / img_in_a_row)
    total_rows = actual_img_rows * img_series
    cells_per_series = actual_img_rows * img_in_a_row

    axes = [
        [fig.add_subplot(total_rows, img_in_a_row, series * cells_per_series + j + 1, **subplot_kw)
         for j in range(img_num)]
        for series in range(img_series)
    ]
    return np.array(axes)

def fix_3d_axis_equal(ax):
    """Make the 3D axes box a cube by setting its box aspect to (1, 1, 1).

    This equalises the displayed box shape only. Data units look equal only if
    the three axis limits also span the same range.
    """
    cube_aspect = (1, 1, 1)
    ax.set_box_aspect(cube_aspect)

def set_axis_range(ax, axis_range):
    """Apply x, y, z limits from a flat ``(x0, x1, y0, y1, z0, z1)`` sequence.

    Limits are set in x, y, z order. Each setter is looked up only when it is
    needed.
    """
    for start, setter_name in zip((0, 2, 4), ('set_xlim', 'set_ylim', 'set_zlim')):
        getattr(ax, setter_name)(axis_range[start], axis_range[start + 1])

def make_view_change_animation(fig, ax, frames_count=40, elev=(35, -10), azim=(0, 360)):
    """Build an animation that sweeps the viewing angle of a 3D axes.

    Frame ``i`` (``0 <= i < frames_count``) calls ``ax.view_init`` with values
    interpolated linearly from ``(elev[0], azim[0])`` towards
    ``(elev[1], azim[1])`` at fraction ``i / frames_count``. The last frame
    therefore stops one step before the end values, so a 0 -> 360 azimuth sweep
    loops without repeating a frame.

    Side effect: sets the global ``plt.rcParams["animation.html"]`` to
    ``"jshtml"``, so the animation's HTML representation (e.g. in Jupyter)
    uses the JavaScript player.

    Args:
        fig: figure the animation redraws.
        ax: 3D axes whose view is changed.
        frames_count: number of frames.
        elev: ``(start, end)`` elevation in degrees.
        azim: ``(start, end)`` azimuth in degrees.

    Returns:
        The ``matplotlib.animation.FuncAnimation``. Callers must keep a
        reference to it for the animation to keep running.
    """
    import matplotlib.animation
    plt.rcParams["animation.html"] = "jshtml"

    def animate(i):
        ax.view_init(elev[0] + (elev[1] - elev[0]) * i / frames_count,
                     azim[0] + (azim[1] - azim[0]) * i / frames_count)

    ani = matplotlib.animation.FuncAnimation(fig, animate, frames=frames_count)
    return ani
