# Original implementation from https://github.com/demul/extrinsic2pyramid. Thanks for the work of @demul! 

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import cv2
from matplotlib.patches import Patch
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
from mpl_toolkits.mplot3d import Axes3D

from utils.pose import convert3x4_4x4

def plt_to_ndarray(fig): 
    """Render a matplotlib figure and return its pixels as an RGBA image array.

    Args:
        fig: figure whose canvas provides ``buffer_rgba()`` (Agg-based canvases,
            which every canvas used in this file is expected to be).

    Returns:
        np.ndarray of dtype uint8 and shape (height, width, 4) with channels in
        RGBA order. The array is a C-contiguous copy, so it stays valid after the
        figure is redrawn or closed and can be passed to OpenCV functions that
        require contiguous buffers.
    """
    fig.tight_layout(pad=0)
    fig.canvas.draw()
    # buffer_rgba() exposes the renderer's pixels already in RGBA order and with
    # the renderer's own shape, so no manual ARGB->RGBA shuffle and no reliance
    # on the logical get_width_height() size (which can differ from the physical
    # pixel size on HiDPI canvases). Copy because the renderer reuses the buffer.
    return np.asarray(fig.canvas.buffer_rgba()).copy()


def draw_distribution(x, p, info=None): 
    """Plot ``p`` against ``x`` on a fresh 8x8-inch, 150-dpi figure and return the image.

    Args:
        x: sample positions (horizontal axis).
        p: values to plot at those positions.
        info: accepted for compatibility but currently unused (the text overlay
            is disabled).

    Returns:
        The rendered figure as produced by ``plt_to_ndarray``.

    Note:
        Before returning, this closes *every* open pyplot figure
        (``plt.close('all')``), not only the one it created.
    """
    canvas_fig = plt.figure(figsize=(8, 8), dpi=150)
    canvas_fig.gca().plot(x, p)
    rendered = plt_to_ndarray(canvas_fig)
    canvas_fig.clf()
    plt.close('all')
    return rendered

class CameraPoseVisualizer: 
    """Draw camera poses as pyramids, plus optional point clouds, on a 3D matplotlib figure.

    Adapted from https://github.com/demul/extrinsic2pyramid.
    """

    def __init__(self, scale, n_colors=None, dpi=300): 
        """Create an 8x8-inch 3D figure whose x, y and z limits are all [-scale, scale].

        Args:
            scale: half-extent of every axis.
            n_colors: if given, colours come from the 'rainbow' colormap resampled
                to ``n_colors`` discrete entries; otherwise the 'jet' colormap is used.
            dpi: figure resolution in dots per inch.
        """
        limits = [-scale, scale]
        self.fig = plt.figure(figsize=(8, 8), dpi=dpi)
        self.ax = self.fig.add_subplot(projection=Axes3D.name)
        self.ax.set_aspect("auto")
        self.ax.set_xlim(limits)
        self.ax.set_ylim(limits)
        self.ax.set_zlim(limits)
        self.ax.set_xlabel('x')
        self.ax.set_ylabel('y')
        self.ax.set_zlabel('z')

        # ``matplotlib.cm.get_cmap`` was removed in Matplotlib 3.9; the
        # ``matplotlib.colormaps`` registry is the supported lookup and exists in
        # every version that has ``Colormap.resampled`` (already relied on here).
        if n_colors is not None:
            self.cmap = mpl.colormaps['rainbow'].resampled(n_colors)
        else:
            self.cmap = mpl.colormaps['jet']

    def extrinsic2pyramid(self, extrinsic, color=0, linewidth=0.1, focal_len_scaled=0.05, aspect_ratio=1, fov=40): 
        """Draw one camera as a pyramid whose apex is the camera centre.

        Args:
            extrinsic: 4x4 homogeneous transform applied to the camera-frame
                vertices (callers in this file pass a camera-to-world matrix).
            color: value passed to ``self.cmap`` to pick the face colour.
            linewidth: edge line width.
            focal_len_scaled: distance from the apex to the pyramid base, which
                lies on the camera's local -z side.
            aspect_ratio: base half-height divided by base half-width.
            fov: opening angle in degrees between the left and right faces.
        """
        face_color = self.cmap(color)
        half_w = focal_len_scaled * np.tan(fov / 180 * np.pi / 2)
        half_h = half_w * aspect_ratio
        depth = -focal_len_scaled
        vertex_std = np.array([[0, 0, 0, 1],                # 0: apex (camera centre)
                               [half_w, -half_h, depth, 1],  # 1-4: base corners
                               [half_w, half_h, depth, 1],
                               [-half_w, half_h, depth, 1],
                               [-half_w, -half_h, depth, 1],
                               [0, 2 * half_h, depth, 1]])   # 5: tip of the "up" marker on the +y edge
        # Transform the homogeneous row vectors, then drop the w component.
        verts = (vertex_std @ extrinsic.T)[:, :-1]
        faces = ((0, 1, 2), (0, 2, 3), (0, 3, 4), (0, 4, 1),  # four side triangles
                 (1, 2, 3, 4),                                # base
                 (2, 3, 5))                                   # "up" marker triangle
        meshes = [[verts[k] for k in face] for face in faces]

        self.ax.add_collection3d(
            Poly3DCollection(meshes, facecolors=face_color, linewidths=linewidth, edgecolors="black", alpha=0.8))

    def customize_legend(self, list_label):
        """Add a legend with one colour patch per label, placed right of the plot.

        Label ``idx`` is coloured ``rainbow(idx / len(list_label))``, which is
        approximately where a 'rainbow' map resampled to ``len(list_label) + 1``
        entries places integer index ``idx``.
        """
        list_handle = [
            Patch(color=plt.cm.rainbow(idx / len(list_label)), label=label)
            for idx, label in enumerate(list_label)
        ]
        self.ax.legend(loc='right', bbox_to_anchor=(1.8, 0.5), handles=list_handle)

    def colorbar(self, max_frame_length): 
        """Attach a vertical 'Frame Number' colorbar spanning 0..max_frame_length."""
        norm = mpl.colors.Normalize(vmin=0, vmax=max_frame_length)
        mappable = mpl.cm.ScalarMappable(norm=norm, cmap=mpl.cm.rainbow)
        # The mappable belongs to no Axes, so say explicitly which Axes gives up
        # space; newer Matplotlib no longer falls back to the current Axes.
        self.fig.colorbar(mappable, ax=self.ax, orientation='vertical', label='Frame Number')

    def add_pcd(self, xyz, color, point_size=0.0001): 
        """Scatter a point cloud onto the 3D axes.

        Args:
            xyz: array indexed as ``xyz[:, 0]``, ``xyz[:, 1]``, ``xyz[:, 2]`` for
                the x, y and z coordinates.
            color: any matplotlib colour spec; None means 'black'.
            point_size: marker size applied to every point.
        """
        if color is None:
            color = 'black'
        sizes = point_size * np.ones_like(xyz[:, 0])
        self.ax.scatter(xyz[:, 0], xyz[:, 1], xyz[:, 2], s=sizes, color=color, marker='o')

    def save(self, img_path): 
        """Save this visualizer's figure to ``img_path`` and close it in pyplot.

        Only ``self.fig`` is saved and closed, whichever figure pyplot considers
        current. The figure keeps its content, so (with a non-interactive Agg
        canvas) it can still be rendered by ``to_ndarray``/``multiple_angles``.
        """
        self.fig.savefig(img_path, bbox_inches='tight')
        plt.close(self.fig)

    def show(self):
        """Title the 3D axes 'Extrinsic Parameters' and display figures with ``plt.show``."""
        self.ax.set_title('Extrinsic Parameters')
        plt.show()

    def to_ndarray(self): 
        """Render the figure and return it as an image array via ``plt_to_ndarray``."""
        return plt_to_ndarray(self.fig)

    def multiple_angles(self, n_images, no_background=True): 
        """Render the scene from ``n_images`` azimuths evenly spaced over 360 degrees.

        Args:
            n_images: number of views to render.
            no_background: if True (the default, matching the previous fixed
                behaviour), make the figure and axes backgrounds transparent and
                hide the axes. The change is made on the figure in place and
                persists after this call.

        Returns:
            list with one image per view, as returned by ``to_ndarray``.
        """
        if no_background:
            self.fig.patch.set_alpha(0.0)
            self.ax.patch.set_alpha(0.0)
            self.ax.axis('off')

        img_list = []
        for i in range(n_images):
            # NOTE(review): the (i - 1) offset starts at azim = -60 - 360/n_images
            # and puts the default view (-60) second; kept for output compatibility.
            self.ax.view_init(elev=20, azim=-60 + (i - 1) * 360 / n_images, roll=0)
            img_list.append(self.to_ndarray())
        return img_list

def visualize_matplotlib(scale, list_of_c2w_list, linewidth=0.5, camera_color=None, show=False, mesh_path=None, pcd_size=1, mesh_color=None, camera_size=0.4, xyz_np=None, n_plot_view=4, save_path=None, viewspace_size_ratio=2): 
    """Draw one or more camera trajectories (and an optional point cloud) in 3D.

    Args:
        scale: scene scale; the plot axes span +/- ``viewspace_size_ratio * scale``.
        list_of_c2w_list: list of trajectories, each an iterable of poses. Each pose
            goes through ``convert3x4_4x4`` and then ``.cpu().numpy()``
            (presumably torch tensors — confirm against callers).
        linewidth: pyramid edge width.
        camera_color: optional per-camera colour values (indexed by camera
            position); only allowed with exactly one trajectory. Without it,
            trajectory ``i`` is coloured by integer index ``i``.
        show, mesh_path: accepted but unused.
        pcd_size: marker size for the point cloud.
        mesh_color: point-cloud colour (None -> black).
        camera_size: pyramid depth (``focal_len_scaled``).
        xyz_np: optional point cloud to scatter.
        n_plot_view: number of views rendered by ``multiple_angles``.
        save_path: if given, the figure is saved here before rendering views.
        viewspace_size_ratio: multiplier applied to ``scale`` for the axis limits.

    Returns:
        The list of images produced by ``CameraPoseVisualizer.multiple_angles``.
    """
    n_trajectories = len(list_of_c2w_list)
    half_extent = viewspace_size_ratio * scale
    per_camera_colors = camera_color is not None

    if per_camera_colors:
        assert n_trajectories == 1, "camera_color only support one list of camera traj"
        visualizer = CameraPoseVisualizer(half_extent)
    else:
        # One discrete colour slot per trajectory plus one extra.
        visualizer = CameraPoseVisualizer(half_extent, n_colors=n_trajectories + 1)

    for traj_idx, c2w_list in enumerate(list_of_c2w_list):
        for cam_idx, c2w in enumerate(c2w_list):
            pose = convert3x4_4x4(c2w).cpu().numpy()
            pyramid_color = camera_color[cam_idx] if per_camera_colors else traj_idx
            visualizer.extrinsic2pyramid(pose, pyramid_color, linewidth, camera_size)

    if xyz_np is not None:
        visualizer.add_pcd(xyz_np, color=mesh_color, point_size=pcd_size)

    if save_path is not None:
        # NOTE(review): if CameraPoseVisualizer.save clears the figure, the
        # multi-angle renders below come out blank; verify against save().
        visualizer.save(save_path)

    return visualizer.multiple_angles(n_plot_view)

def plot_loss_curve(x, y, save_path=None, resize=None): 
    """Plot a loss-versus-step curve and either save it or return it as an image.

    Args:
        x: step values.
        y: loss values.
        save_path: if given, the plot is written to this path and None is returned.
        resize: optional ``(width, height)`` passed to ``cv2.resize`` on the
            returned image; None keeps the rendered size.

    Returns:
        None when ``save_path`` is given; otherwise the image from
        ``plt_to_ndarray`` (resized if ``resize`` is given).

    The figure is always closed before returning, so repeated calls do not
    accumulate open figures.
    """
    fig = plt.figure()
    try:
        ax = fig.add_subplot()
        ax.plot(x, y)
        ax.set_title('Loss versus step')
        ax.set_xlabel("step")
        ax.set_ylabel("loss")

        if save_path is not None:
            fig.savefig(save_path, bbox_inches='tight')
            return None

        img = plt_to_ndarray(fig)
        # cv2.resize(img, None) raises, so only resize when a size was requested.
        if resize is not None:
            img = cv2.resize(img, resize)
        return img
    finally:
        plt.close(fig)

def test(): 
    """Smoke test: render a small line plot eight times and write it to test.jpg."""
    x, y = np.arange(0, 10, 1), np.arange(0, 10, 1)
    for _ in range(8):
        img = draw_distribution(x, y, info="good")
        # draw_distribution returns RGBA while cv2.imwrite expects BGR order;
        # without conversion red and blue would be swapped in the file.
        cv2.imwrite("test.jpg", cv2.cvtColor(img, cv2.COLOR_RGBA2BGR))

def test_plot(): 
    """Smoke test: save a three-point loss curve to test.png."""
    steps, losses = range(3), [0, 1, 2]
    plot_loss_curve(steps, losses, save_path="test.png")

def create_test_traj(radius, focus_depth, upper_half=True, n_poses=(10, 10)): 
    """Generate a grid of camera poses on a sphere, oriented relative to a focus point.

    Camera centres are ``radius * (cos(lon) sin(lat), sin(lon) sin(lat), cos(lat))``.
    Each camera's local +z axis points from ``(0, 0, -focus_depth)`` towards its
    centre (so its -z axis faces that point), and world +z is the up hint.

    Args:
        radius: sphere radius.
        focus_depth: the reference point is ``(0, 0, -focus_depth)``.
        upper_half: if True ``lat`` spans [0, pi/2), otherwise [-pi/2, pi/2).
        n_poses: (number of ``lat`` samples, number of longitude samples).

    Returns:
        np.ndarray of shape (n_poses[0] * n_poses[1], 3, 4). Each pose's columns
        are the x, y, z axes and the camera centre. Longitude is the outer loop.

    NOTE(review): ``lat`` is used as a polar angle (``cos(lat)`` is the z
    component), so negative values only mirror positive ones through the z axis
    and ``upper_half=False`` never reaches the lower hemisphere. Kept as-is;
    confirm the intended convention.
    """
    def _normalize(v):
        return v / np.linalg.norm(v)

    up = np.array([0.0, 0.0, 1.0])
    focus = np.array([0.0, 0.0, -focus_depth])
    lat_start = 0 if upper_half else -0.5 * np.pi
    lat_array = np.linspace(lat_start, 0.5 * np.pi, n_poses[0] + 1)[:-1]  # vertically
    longit_array = np.linspace(0, 2 * np.pi, n_poses[1] + 1)[:-1]  # horizontally

    poses = []
    for longit in longit_array:
        for lat in lat_array:
            center = np.array([np.cos(longit) * np.sin(lat),
                               np.sin(longit) * np.sin(lat),
                               np.cos(lat)]) * radius
            z = _normalize(center - focus)

            x_raw = np.cross(up, z)
            if np.linalg.norm(x_raw) < 1e-8:
                # z is (anti)parallel to the up hint (e.g. lat == 0 puts the camera
                # at the pole), so up x z vanishes and normalising it would give
                # NaNs. Use the lat -> 0+ limit of the regular branch instead.
                x = np.array([-np.sin(longit), np.cos(longit), 0.0])
            else:
                x = _normalize(x_raw)
            y = np.cross(z, x)

            poses.append(np.stack([x, y, z, center], 1))  # (3, 4)

    return np.stack(poses, 0)  # (n_poses, 3, 4)
