import json
import math
import sys
import traceback
from pathlib import Path
from typing import Callable, List

import matplotlib.pyplot as plt
import numpy as np
import torch
import trimesh

from occupancy_networks import click, load_model, mkdir
from occupancy_networks import point_cloud as pcl
from occupancy_networks.datasets.sdf import (MeshData, Minibatch,
                                             MinibatchGenerator, OccupancyData,
                                             SdfMeshPairDataset)
from occupancy_networks.experiment import (MarchingCubes, Model,
                                           ModelHyperparameters,
                                           TrainingHyperparameters,
                                           setup_model)


def _normalize(vec: np.ndarray):
    return vec / np.linalg.norm(vec)


def _look_at(eye: np.ndarray, center: np.ndarray, up: np.ndarray):
    eye = np.asanyarray(eye)
    center = np.asanyarray(center)
    up = np.asanyarray(up)

    z = _normalize(eye - center)
    x = np.cross(up, z)
    y = np.cross(z, x)

    x = _normalize(x)
    y = _normalize(y)

    rotation_matrix = np.array(
        [
            [x[0], y[0], z[0]],
            [x[1], y[1], z[1]],
            [x[2], y[2], z[2]],
        ],
        dtype=np.float32,
    )
    translation_vector = np.array([-x @ eye, -y @ eye, -z @ eye])

    return rotation_matrix, translation_vector


class Plot:
    """Render a comparison figure of input points, reconstruction and GT.

    The figure has one row per entry of ``num_points_samples_list`` and three
    columns:

    1. The sampled input point cloud.
    2. The mesh produced by ``marching_cubes_func`` from a latent code that
       the model inferred from those points.
    3. The ground-truth mesh, which is the same image in every row.
    """

    def __init__(self, model: Model, data_generator: MinibatchGenerator,
                 num_points_samples_list: List[int],
                 marching_cubes_func: Callable[[torch.Tensor], tuple]):
        """
        Args:
            model: Network whose ``infer_z`` encodes the input points.
            data_generator: Called with a list of ``OccupancyData`` items to
                build a minibatch.
            num_points_samples_list: Number of input points to sample for
                each row of the figure.
            marching_cubes_func: Maps a latent code ``z`` to a
                ``(vertices, faces)`` pair.
        """
        self.model = model
        self.data_generator = data_generator
        self.num_points_samples_list = num_points_samples_list
        self.marching_cubes = marching_cubes_func

    @staticmethod
    def _camera_transform():
        """Return ``(rotation_matrix, translation_vector)`` for the fixed camera.

        Points ``p`` with shape ``(N, 3)`` are mapped into the camera frame as
        ``(rotation_matrix @ p.T).T + translation_vector``. The returned
        ``translation_vector`` has shape ``(1, 3)`` so that it broadcasts
        over ``N``.
        """
        camera_theta = math.pi / 3
        camera_phi = -math.pi / 4
        camera_r = 1
        eye = [
            camera_r * math.sin(camera_theta) * math.cos(camera_phi),
            camera_r * math.cos(camera_theta),
            camera_r * math.sin(camera_theta) * math.sin(camera_phi),
        ]
        rotation_matrix, translation_vector = _look_at(eye=eye,
                                                       center=[0, 0, 0],
                                                       up=[0, 1, 0])
        # _look_at stores the camera axes as columns; inverting gives the
        # matrix whose rows are those axes (world -> camera rotation).
        return np.linalg.inv(rotation_matrix), translation_vector[None, :]

    @staticmethod
    def _cmap_binary(points: np.ndarray) -> np.ndarray:
        """Grey-scale RGB colors shaded by each point's x coordinate.

        Returns an ``(N, 3)`` array with values in ``[0, 0.3]``. The point with
        the largest positive x is darkest. If every x is 0, all points get the
        mid value 0.15. The input array is not modified.
        """
        x = points[:, 0]
        max_abs = np.max(np.abs(x)) if x.size else 0
        if max_abs > 0:
            # Out-of-place so the caller's points are not overwritten.
            x = -x / max_abs
        else:
            x = np.zeros_like(x)
        intensity = 0.3 * (x + 1) / 2
        return np.repeat(intensity[:, None], 3, axis=1)

    @staticmethod
    def _render_ground_truth(occupancy_data: OccupancyData,
                             mesh_data: MeshData, rotation_matrix: np.ndarray,
                             translation_vector: np.ndarray):
        """Render the GT mesh in the camera frame; return None on IndexError.

        Before the camera transform, the mesh vertices are shifted by
        ``occupancy_data.offset`` and then multiplied by
        ``occupancy_data.scale``.
        """
        try:
            vertices = (mesh_data.vertices +
                        occupancy_data.offset) * occupancy_data.scale
            # translation_vector is already (1, 3); adding another leading
            # axis here would turn the vertices into an invalid (1, N, 3).
            vertices = (rotation_matrix @ vertices.T).T + translation_vector
            mesh = trimesh.Trimesh(vertices=vertices,
                                   faces=mesh_data.vertex_indices)
            return pcl.render_mesh(mesh, camera_mag=1)
        except IndexError as e:
            print(f"ground truth rendering failed: {e}", flush=True)
            return None

    def _render_reconstruction(self, minibatch: Minibatch,
                               rotation_matrix: np.ndarray,
                               translation_vector: np.ndarray):
        """Encode the minibatch, run marching cubes and render the mesh.

        Returns the rendered image, or None if a ValueError is raised along
        the way. That failure is treated as best-effort: the column is left
        blank and the error is reported.
        """
        try:
            with torch.no_grad():
                q_z = self.model.infer_z(minibatch.input_points,
                                         minibatch.input_occupancies,
                                         c=None)
                z = q_z.sample()
            mc_vertices, mc_faces = self.marching_cubes(z)
            mc_vertices = (
                rotation_matrix @ mc_vertices.T).T + translation_vector
            mesh = trimesh.Trimesh(vertices=mc_vertices, faces=mc_faces)
            return pcl.render_mesh(mesh, camera_mag=1)
        except ValueError as e:
            print(f"reconstruction failed: {e}", flush=True)
            return None

    def __call__(self, occupancy_data: OccupancyData, mesh_data: MeshData):
        """Build the comparison figure for one dataset item.

        Args:
            occupancy_data: Item passed to ``data_generator`` to sample input
                points. Its ``offset``/``scale`` are also applied to the GT
                mesh vertices.
            mesh_data: Ground-truth mesh (``vertices`` and ``vertex_indices``).

        Returns:
            The matplotlib figure. The caller is responsible for saving and
            closing it. If an exception escapes, the figure is closed before
            the exception is re-raised.
        """
        rotation_matrix, translation_vector = self._camera_transform()

        num_rows = len(self.num_points_samples_list)
        dpi = 100
        figsize_inch = np.array([600, 200 * num_rows]) / dpi

        gt_image = self._render_ground_truth(occupancy_data, mesh_data,
                                             rotation_matrix,
                                             translation_vector)

        # squeeze=False keeps ``axes`` two-dimensional even for a single row.
        fig, axes = plt.subplots(num_rows,
                                 3,
                                 figsize=figsize_inch,
                                 squeeze=False)
        try:
            for row, num_point_samples in enumerate(
                    self.num_points_samples_list):
                print(f"row {row+1} of {num_rows}", flush=True)
                input_axis, recon_axis, gt_axis = axes[row]

                # Re-seed for every row so the sampled points are reproducible.
                random_state = np.random.RandomState(0)
                minibatch = self.data_generator(
                    [occupancy_data],
                    num_input_points=num_point_samples,
                    num_gt_points=0,
                    random_state=random_state)

                # Input point cloud column.
                input_points = minibatch.input_points[0].detach().cpu().numpy()
                input_points = input_points.reshape((-1, 3))
                points = (rotation_matrix @ input_points.T).T + translation_vector
                image = pcl.render_point_cloud(points,
                                               self._cmap_binary(input_points),
                                               camera_mag=1,
                                               point_size=6)
                input_axis.imshow(image)
                input_axis.set_ylabel(num_point_samples)

                # Reconstruction column (left blank if it fails).
                recon_image = self._render_reconstruction(
                    minibatch, rotation_matrix, translation_vector)
                if recon_image is not None:
                    recon_axis.imshow(recon_image)

                # Ground-truth column (same image in every row).
                if gt_image is not None:
                    gt_axis.imshow(gt_image)

                for axis in (input_axis, recon_axis, gt_axis):
                    axis.set_xticks([])
                    axis.set_yticks([])
        except Exception:
            # Do not leak the half-built figure when a row fails.
            plt.close(fig)
            raise

        for axis, title in zip(axes[0],
                               ("Input points", "Reconstruction",
                                "Ground truth")):
            axis.set_title(title, fontsize=10)
        return fig


@click.group()
def client():
    # Root command group of this script; subcommands such as
    # ``plot_dataset`` attach to it through ``@client.command``.
    return None


def _category_and_model_id(path):
    """Return ``(category_id, model_id)`` parsed from a dataset file path.

    Two layouts are recognised:

    - ``<category_id>/<model_id>/models/<file>``: the OBJ layout that
      ``plot_dataset`` builds.
    - ``<category_id>/<model_id>/<file>``: for example ``sdf.npz``.

    Raises IndexError if the path has too few components.
    """
    parts = Path(path).parts
    if len(parts) >= 4 and parts[-2] == "models":
        return parts[-4], parts[-3]
    return parts[-3], parts[-2]


@client.command(name="plot_dataset")
@click.argument("--output-directory", type=str, required=True)
@click.argument("--checkpoint-directory", type=str, required=True)
@click.argument("--checkpoint-epoch", type=int, default=None)
@click.argument("--sdf-dataset-directory", type=str, required=True)
@click.argument("--obj-dataset-directory", type=str, required=True)
@click.argument("--test-split-path", type=str, required=True)
@click.argument("--grid-size", type=int, default=256)
def plot_dataset(args):
    """Save a comparison figure for every test-split object with SDF + OBJ data.

    One PNG named ``<category_id>_<model_id>.png`` is written per object into
    ``--output-directory``. Objects whose processing raises IndexError are
    reported and skipped.
    """
    device = torch.device("cuda", 0)
    sdf_dataset_directory = Path(args.sdf_dataset_directory)
    obj_dataset_directory = Path(args.obj_dataset_directory)
    output_directory = Path(args.output_directory)
    mkdir(output_directory)

    checkpoint_directory = Path(args.checkpoint_directory)
    args_path = checkpoint_directory / "args.json"
    if args.checkpoint_epoch is None:
        model_path = checkpoint_directory / "model.pt"
    else:
        model_path = checkpoint_directory / f"model.{args.checkpoint_epoch}.pt"
    test_split_path = Path(args.test_split_path)
    # Explicit checks instead of ``assert``, which is stripped under -O.
    for required_file in (args_path, model_path, test_split_path):
        if not required_file.is_file():
            raise FileNotFoundError(f"required file not found: {required_file}")

    model_hyperparams = ModelHyperparameters.load_json(args_path)
    training_hyperparams = TrainingHyperparameters.load_json(args_path)

    model = setup_model(model_hyperparams)
    load_model(model_path, model)
    model.to(device)
    model.eval()

    with open(test_split_path) as f:
        split = json.load(f)

    # Keep only objects for which both the SDF samples and the mesh exist.
    sdf_obj_path_list = []
    for category_id, model_id_list in split.items():
        for model_id in model_id_list:
            sdf_path = sdf_dataset_directory / category_id / model_id / "sdf.npz"
            obj_path = (obj_dataset_directory / category_id / model_id /
                        "models" / "model_normalized.obj")
            if sdf_path.exists() and obj_path.exists():
                sdf_obj_path_list.append((sdf_path, obj_path))

    dataset = SdfMeshPairDataset(sdf_obj_path_list)

    data_generator = MinibatchGenerator(
        num_input_points=training_hyperparams.num_input_points,
        num_gt_points=training_hyperparams.num_gt_points,
        noise_stddev=0,
        device=device)

    grid_max_value = 1
    grid_min_value = -1
    marching_cubes = MarchingCubes(model=model,
                                   grid_size=args.grid_size,
                                   grid_max_value=grid_max_value,
                                   grid_min_value=grid_min_value)

    num_points_samples_list = [50, 100, 500, 1000, 5000]
    plot = Plot(model=model,
                data_generator=data_generator,
                num_points_samples_list=num_points_samples_list,
                marching_cubes_func=marching_cubes)

    for data_tuple in dataset:
        occupancy_data = data_tuple[0]
        mesh_data = data_tuple[1]
        fig = None
        try:
            fig = plot(occupancy_data, mesh_data)

            category_id, model_id = _category_and_model_id(mesh_data.path)
            figure_path = output_directory / f"{category_id}_{model_id}.png"
            fig.tight_layout()
            fig.subplots_adjust(top=0.92)
            fig.savefig(figure_path,
                        dpi=300,
                        bbox_inches="tight",
                        pad_inches=0.05)
            print(figure_path, flush=True)
        except IndexError as e:
            # Skip this object but report which error occurred and where.
            print(f"IndexError: {e}", flush=True)
            traceback.print_exc()
        finally:
            # Always release the figure so memory does not grow per object.
            if fig is not None:
                plt.close(fig)


if __name__ == "__main__":
    # Run the command-line interface when this file is executed as a script.
    client()
