import pickle

from absl import app
from absl import flags
import taichi as ti
from matplotlib import animation
import matplotlib.pyplot as plt
import numpy as np
import os
from pyevtk.hl import pointsToVTK
import imageio
from utils.draw_control_plot import visualize_pose
# Command-line flags. Directory flags are used as string prefixes (they are
# concatenated directly with file names), so pass them with a trailing "/".
flags.DEFINE_string("gif_dir", None, help="Output directory for gif/mp4/vtk renders (end with a path separator)")
flags.DEFINE_string("ply_dir", None, help="Output directory for .ply files (end with a path separator)")
flags.DEFINE_string("rollout_dir", None, help="Directory where rollout.pkl are located")
flags.DEFINE_string("rollout_name", None, help="Name of rollout `.pkl` file (without the `.pkl` extension)")
flags.DEFINE_integer("step_stride", 3, help="Stride of steps to skip.")
flags.DEFINE_bool("change_yz", True, help="Change y and z axis.")
flags.DEFINE_enum("output_mode", "gif", ["gif", "vtk","ply","mp4"], help="Type of render output")

FLAGS = flags.FLAGS

# Particle type id -> matplotlib color name used when drawing that type.
# NOTE: insertion order matters; the renderer collects (and draws) particle
# groups in this order, so do not reorder entries casually.
TYPE_TO_COLOR = {
    1: "red",  # Droplet.
    3: "black",  # Boundary particles.
    0: "green",  # Rigid solids.
    7: "magenta",  # Goop.
    6: "gold",  # Sand.
    5: "blue",  # Water.
}


class Render():
    """
    Render rollout data into gif, mp4, vtk or ply files.

    The rollout pickle is read for the keys `initial_positions`,
    `ground_truth_rollout`, `predicted_rollout`, `loss`, `particle_types`
    and `metadata["bounds"]`.
    """

    def __init__(self, input_dir, input_name, output_dir=None):
        """
            Initialize render class by loading `{input_dir}{input_name}.pkl`.

        Args:
            input_dir (str): Directory where rollout.pkl are located. It is
                concatenated verbatim with `input_name`, so it should end with
                a path separator.
            input_name (str): Name of rollout `.pkl` file (without extension).
                Also used as the base name of every rendered output.
            output_dir (str, optional): Prefix for rendered outputs; like
                `input_dir` it is concatenated verbatim with file names.
        """
        # Texts to describe rollout cases for data and render:
        # [key in the rollout pickle, title used in plots / output names]
        rollout_cases = [
            ["ground_truth_rollout", "Reality"], ["predicted_rollout", "GNS"]]
        self.rollout_cases = rollout_cases
        self.input_dir = input_dir
        self.input_name = input_name
        self.output_dir = output_dir
        self.output_name = input_name

        # NOTE(security): pickle.load can execute arbitrary code; only load
        # rollout files from trusted sources.
        with open(f"{self.input_dir}{self.input_name}.pkl", "rb") as file:
            rollout_data = pickle.load(file)
        self.rollout_data = rollout_data

        # Prepend the initial positions so every trajectory starts at step 0.
        trajectory = {}
        for rollout_case, _ in rollout_cases:
            trajectory[rollout_case] = np.concatenate(
                [rollout_data["initial_positions"], rollout_data[rollout_case]], axis=0
            )
        self.trajectory = trajectory

        # `loss` may be a tensor / numpy scalar (exposes `.item()`) or a plain
        # Python number; accept both.
        loss = rollout_data['loss']
        self.loss = loss.item() if hasattr(loss, "item") else float(loss)

        # Trajectory information: arrays are indexed [step, particle, coordinate].
        first_case = trajectory[rollout_cases[0][0]]
        self.dims = first_case.shape[2]
        self.num_particles = first_case.shape[1]
        self.num_steps = first_case.shape[0]
        self.boundaries = rollout_data["metadata"]["bounds"]
        self.particle_type = rollout_data["particle_types"]

    def color_map(self):
        """
        Get a per-particle list of color names for visualization.

        Returns:
            list: One color name (from TYPE_TO_COLOR) per particle; particles
            whose type is not in TYPE_TO_COLOR get None.
        """
        particle_type = np.asarray(self.particle_type)
        # An object array from np.empty starts filled with None.
        color_map = np.empty(self.num_particles, dtype="object")
        for material_id, color in TYPE_TO_COLOR.items():
            color_map[particle_type == material_id] = color
        return list(color_map)

    def color_mask(self):
        """
        Get boolean particle masks and their colors for visualization.

        Returns:
            list: `[mask, color]` pairs, in TYPE_TO_COLOR order, for every
            particle type that occurs at least once.
        """
        particle_type = np.asarray(self.particle_type)
        color_mask = []
        for material_id, color in TYPE_TO_COLOR.items():
            mask = particle_type == material_id
            if mask.any():
                color_mask.append([mask, color])
        return color_mask

    @staticmethod
    def _snapshot_frames(num_steps, num_snapshots=20):
        """Return the set of frame indices at which PNG snapshots are saved.

        Picks `num_snapshots + 1` evenly spaced indices over [0, num_steps];
        the last one is clamped to `num_steps - 1` so the final frame (which
        is a valid index) is actually saved.
        """
        last = num_steps - 1
        return {min(int(num_steps * i / num_snapshots), last)
                for i in range(num_snapshots + 1)}

    @staticmethod
    def _material_from_colors(colors):
        """Return "water" if any present particle type is drawn blue, else "sand"."""
        return 'water' if 'blue' in colors else 'sand'

    @staticmethod
    def _axis_order(change_yz):
        """Return the data-coordinate indices plotted on the 3D (x, y, z) axes.

        With `change_yz` the data y coordinate goes on the plot's z axis and
        the data z coordinate on the plot's y axis.
        """
        return (0, 2, 1) if change_yz else (0, 1, 2)

    def render_vis(self):
        """
        Render the ground-truth trajectory frame by frame with `visualize_pose`.

        Writes `{output_dir}{output_name}/render.mp4` (30 fps) and PNG snapshots
        `{output_dir}{output_name}/NNN.png` at up to 21 evenly spaced frames,
        including the first and the last frame.
        """
        # Ground truth is rendered; use self.rollout_cases[1][0] for the prediction.
        gt = self.rollout_cases[0][0]
        xboundary = self.boundaries[0]
        color_mask = self.color_mask()
        material = self._material_from_colors([color for _, color in color_mask])
        # The `is_ob` flag passed to visualize_pose is hard-coded off.
        is_ob = False

        out_dir = f'{self.output_dir}{self.output_name}'
        os.makedirs(out_dir, exist_ok=True)
        save_points = self._snapshot_frames(self.num_steps)
        video = []
        for frame in range(self.num_steps):
            print(f"Render step {frame}/{self.num_steps}")
            positions = self.trajectory[gt][frame]
            img = visualize_pose(positions, xboundary[0], xboundary[1], material, is_ob, color_mask)
            video.append(img)
            if frame in save_points:
                imageio.imwrite(f'{out_dir}/{frame:03d}.png', img)
        imageio.mimsave(f"{out_dir}/render.mp4", video, fps=30)

    def render_gif_animation(
            self, point_size=1, timestep_stride=3, vertical_camera_angle=20, viewpoint_rotation=0.5, change_yz=False
    ):
        """
        Render `.gif` animation from `.pkl` trajectory data.

        Ground truth and prediction are drawn side by side and the animation
        is saved to `{output_dir}{output_name}.gif`.

        Args:
            point_size (int): Size of particle in visualization
            timestep_stride (int): Stride of steps to skip.
            vertical_camera_angle (float): Vertical camera angle in degree.
                Currently unused (camera rotation is disabled).
            viewpoint_rotation (float): Viewpoint rotation in degree per frame.
                Currently unused (camera rotation is disabled).
            change_yz (bool): For 3D data, swap the y and z axes in the plot.

        Returns:
            None; the gif is written to disk.

        Raises:
            ValueError: If the trajectory is neither 2D nor 3D.
        """
        if self.dims not in (2, 3):
            raise ValueError(
                f"Only 2D or 3D trajectories can be rendered, got dims={self.dims}")

        fig = plt.figure()

        # Define datacase name
        trajectory_datacases = [self.rollout_cases[0][0], self.rollout_cases[1][0]]
        render_datacases = [self.rollout_cases[0][1], self.rollout_cases[1][1]]

        # Get color mask for visualization
        color_mask = self.color_mask()

        if self.dims == 2:
            xboundary = self.boundaries[0]
            yboundary = self.boundaries[1]

            def animate(i):
                print(f"Render step {i}/{self.num_steps}")

                # Axes are rebuilt every frame because the figure is cleared.
                fig.clear()
                for j, datacase in enumerate(trajectory_datacases):
                    positions = self.trajectory[datacase][i]
                    ax = fig.add_subplot(1, 2, j + 1, autoscale_on=False)
                    ax.set_aspect("equal")
                    ax.set_xlim([float(xboundary[0]), float(xboundary[1])])
                    ax.set_ylim([float(yboundary[0]), float(yboundary[1])])
                    for mask, color in color_mask:
                        ax.scatter(positions[mask, 0], positions[mask, 1],
                                   s=point_size, color=color)
                    ax.grid(True, which='both')
                    ax.set_title(render_datacases[j])
                fig.suptitle(f"{i}/{self.num_steps}, Total MSE: {self.loss:.2e}")
        else:
            # Data coordinate index drawn on each plot axis, and its bounds.
            order = self._axis_order(change_yz)
            bounds = [(float(self.boundaries[k][0]), float(self.boundaries[k][1]))
                      for k in order]

            def animate(i):
                print(f"Render step {i}/{self.num_steps} for {self.output_name}")

                # Axes are rebuilt every frame because the figure is cleared.
                fig.clear()
                for j, datacase in enumerate(trajectory_datacases):
                    positions = self.trajectory[datacase][i]
                    ax = fig.add_subplot(1, 2, j + 1, projection='3d', autoscale_on=False)
                    ax.set_xlim(list(bounds[0]))
                    ax.set_ylim(list(bounds[1]))
                    ax.set_zlim(list(bounds[2]))
                    for mask, color in color_mask:
                        ax.scatter(positions[mask, order[0]],
                                   positions[mask, order[1]],
                                   positions[mask, order[2]], s=point_size, color=color)
                    # Size the box by the domain extents so units are equal on all axes.
                    ax.set_box_aspect(aspect=tuple(hi - lo for lo, hi in bounds))
                    ax.grid(True, which='both')
                    ax.set_title(render_datacases[j])
                fig.suptitle(f"{i}/{self.num_steps}, Total MSE: {self.loss:.2e}")

        # Create animation
        ani = animation.FuncAnimation(
            fig, animate, frames=np.arange(0, self.num_steps, timestep_stride), interval=10)

        ani.save(f'{self.output_dir}{self.output_name}.gif', dpi=100, fps=30, writer='imagemagick')
        # Release the figure; repeated calls would otherwise accumulate open figures.
        plt.close(fig)
        print(f"Animation saved to: {self.output_dir}{self.output_name}.gif")

    def write_vtk(self):
        """
        Write `.vtk` files for each timestep for each rollout case.

        Files go to `{output_dir}{output_name}_vtk-{label}/points{i}` and carry
        a per-particle "displacement" field (distance from the first frame).
        2D data is written with z = 0.
        """
        for rollout_case, label in self.rollout_cases:
            path = f"{self.output_dir}{self.output_name}_vtk-{label}"
            os.makedirs(path, exist_ok=True)
            initial_position = self.trajectory[rollout_case][0]
            for i, coord in enumerate(self.trajectory[rollout_case]):
                disp = np.linalg.norm(coord - initial_position, axis=1)
                # np.array(...) copies the strided column slices; presumably
                # pyevtk needs contiguous arrays.
                pointsToVTK(f"{path}/points{i}",
                            np.array(coord[:, 0]),
                            np.array(coord[:, 1]),
                            np.zeros_like(coord[:, 1]) if self.dims == 2 else np.array(coord[:, 2]),
                            data={"displacement": disp})
        print(f"vtk saved to: {self.output_dir}{self.output_name}...")

    def write_ply(self, export_file):
        """
        Export the predicted rollout as one PLY frame per timestep.

        Args:
            export_file (str): Series prefix passed to taichi's
                `PLYWriter.export_frame`, which derives per-frame file names.
        """
        rollout_case = self.rollout_cases[1][0]  # "predicted_rollout"
        for frame in range(self.num_steps):
            positions = self.trajectory[rollout_case][frame]
            # 2D rollouts have no z coordinate; write z = 0 like write_vtk.
            z = np.zeros_like(positions[:, 0]) if self.dims == 2 else positions[:, 2]
            # Fresh writer per frame, as in taichi's export example:
            # add_vertex_pos registers channels on the writer, so sharing one
            # across frames would accumulate them.
            writer = ti.tools.PLYWriter(num_vertices=self.num_particles)
            writer.add_vertex_pos(positions[:, 0], positions[:, 1], z)
            writer.export_frame(frame, export_file)
            
        
def main(_):
    """
    Entry point: validate flags, prepare output directories and dispatch to
    the renderer selected by `--output_mode`.

    gif/mp4/vtk outputs are written under `--gif_dir`; ply output is written
    under `--ply_dir`.

    Raises:
        ValueError: If a flag required for the selected mode is missing.
    """
    if not FLAGS.rollout_dir:
        raise ValueError("A `rollout_dir` must be passed.")
    if not FLAGS.rollout_name:
        raise ValueError("A `rollout_name` must be passed.")

    # gif_dir defaults to None; check it before touching the filesystem
    # (os.path.exists(None) raises TypeError). It is unused in ply mode.
    if FLAGS.output_mode != "ply" and not FLAGS.gif_dir:
        raise ValueError("A `gif_dir` must be passed.")
    if FLAGS.gif_dir:
        os.makedirs(FLAGS.gif_dir, exist_ok=True)
    if FLAGS.output_mode == "ply":
        if not FLAGS.ply_dir:
            raise ValueError("A `ply_dir` must be passed.")
        os.makedirs(FLAGS.ply_dir, exist_ok=True)
        os.makedirs(f"{FLAGS.ply_dir}{FLAGS.rollout_name}", exist_ok=True)
    render = Render(input_dir=FLAGS.rollout_dir, input_name=FLAGS.rollout_name, output_dir=FLAGS.gif_dir)

    if FLAGS.output_mode == "gif":
        render.render_gif_animation(
            point_size=1,
            timestep_stride=FLAGS.step_stride,
            vertical_camera_angle=20,
            viewpoint_rotation=0.3,
            change_yz=FLAGS.change_yz
        )
    elif FLAGS.output_mode == "mp4":
        render.render_vis()
    elif FLAGS.output_mode == "vtk":
        render.write_vtk()
    elif FLAGS.output_mode == "ply":
        render.write_ply(export_file=f"{FLAGS.ply_dir}{FLAGS.rollout_name}/rollout.ply")


if __name__ == '__main__':
    # absl parses the command-line flags before calling main.
    app.run(main)

