import json
import math
import sys
import traceback
from pathlib import Path
from typing import Callable, List

import matplotlib.pyplot as plt
import numpy as np
import torch
import trimesh

from occupancy_networks import click, load_model, mkdir
from occupancy_networks import point_cloud as pcl
from occupancy_networks.datasets.surface.uniform_sparse_sampling import (
    MeshData, Minibatch, MinibatchGenerator, OccupancySurfacePointCloudData,
    SurfaceSdfMeshPairDataset)
from occupancy_networks.experiments.encoder import (MarchingCubes, Model,
                                                    ModelHyperparameters,
                                                    TrainingHyperparameters,
                                                    setup_model)


def _normalize(vec: np.ndarray):
    """Return ``vec`` scaled to unit Euclidean (L2) length.

    There is no zero-length guard: a zero vector is divided by zero, which
    yields NaNs under NumPy's default floating-point error handling.
    """
    magnitude = np.linalg.norm(vec)
    return vec / magnitude


def _look_at(eye: np.ndarray, center: np.ndarray, up: np.ndarray):
    """Build the orthonormal camera basis for a camera at ``eye`` aimed at ``center``.

    Args:
        eye: camera position (anything ``np.asanyarray`` accepts).
        center: point the camera looks at.
        up: approximate world up direction.

    Returns:
        rotation_matrix: float32 (3, 3) array whose columns are the camera's
            x, y and z axes in world coordinates. The z axis points from
            ``center`` toward ``eye``.
        translation_vector: (3,) array ``[-x·eye, -y·eye, -z·eye]``. Paired
            with the transpose (inverse) of ``rotation_matrix``, it maps world
            points into camera coordinates.

    If ``up`` is parallel to ``eye - center``, the cross product is zero and
    the result contains NaNs.
    """
    eye, center, up = (np.asanyarray(v) for v in (eye, center, up))

    z_axis = _normalize(eye - center)
    x_unnormalized = np.cross(up, z_axis)
    y_unnormalized = np.cross(z_axis, x_unnormalized)
    x_axis = _normalize(x_unnormalized)
    y_axis = _normalize(y_unnormalized)

    # Each axis becomes one column of the rotation matrix.
    rotation_matrix = np.column_stack((x_axis, y_axis, z_axis)).astype(np.float32)
    translation_vector = np.array(
        [-x_axis @ eye, -y_axis @ eye, -z_axis @ eye])
    return rotation_matrix, translation_vector


class Plot:
    """Figure factory comparing input points, reconstruction and ground truth.

    Calling an instance builds a matplotlib figure with one row per entry of
    ``num_points_samples_list`` and three columns:

    - the sampled input point cloud,
    - the mesh extracted from the model's encoding of those points,
    - the ground-truth mesh.

    All panels are rendered from the same fixed camera.
    """

    def __init__(self, model: Model, data_generator: MinibatchGenerator,
                 num_points_samples_list: List[int],
                 marching_cubes_func: Callable[[torch.Tensor], tuple]):
        """Store the collaborators used to build each figure.

        Args:
            model: network whose ``encode_inputs`` maps input points to a
                latent code.
            data_generator: called to sample a minibatch of input points.
            num_points_samples_list: number of input points for each row.
            marching_cubes_func: maps the latent code to a
                ``(vertices, faces)`` pair. These are presumably numpy arrays,
                since the vertices are multiplied by a numpy rotation matrix
                (TODO confirm).
        """
        self.model = model
        self.data_generator = data_generator
        self.num_points_samples_list = num_points_samples_list
        self.marching_cubes = marching_cubes_func

    @staticmethod
    def _camera_transform():
        """Return ``(rotation, translation)`` mapping world points to camera space.

        The camera sits on the unit sphere (theta = pi/3, phi = -pi/4) and
        looks at the origin with +y as up. ``rotation`` is (3, 3).
        ``translation`` is (1, 3) so that it broadcasts over row-stacked
        (N, 3) points.
        """
        camera_theta = math.pi / 3
        camera_phi = -math.pi / 4
        camera_r = 1
        eye = [
            camera_r * math.sin(camera_theta) * math.cos(camera_phi),
            camera_r * math.cos(camera_theta),
            camera_r * math.sin(camera_theta) * math.sin(camera_phi),
        ]
        rotation_matrix, translation_vector = _look_at(eye=eye,
                                                       center=[0, 0, 0],
                                                       up=[0, 1, 0])
        # _look_at returns the camera axes as columns; inverting gives the
        # world-to-camera rotation.
        return np.linalg.inv(rotation_matrix), translation_vector[None, :]

    @staticmethod
    def _cmap_binary(points: np.ndarray) -> np.ndarray:
        """Shade points by their x coordinate with a dark gray level in [0, 0.3].

        The most positive x maps to 0 (black) and the most negative x to 0.3.
        Returns an (N, 3) RGB array. ``points`` is NOT modified (the previous
        version scaled a view of it in place). If every x is 0, or there are
        no points, all points get the mid-level 0.15.
        """
        x = np.array(points[:, 0], copy=True)
        max_abs = np.max(np.abs(x)) if x.size else 0
        if max_abs > 0:
            x *= -1 / max_abs
        intensity = 0.3 * (x + 1) / 2
        return np.repeat(intensity[:, None], 3, axis=1)

    def __call__(self, surface_data: OccupancySurfacePointCloudData,
                 mesh_data: MeshData):
        """Build the comparison figure for one object.

        Args:
            surface_data: surface samples fed to ``data_generator``. Its
                ``offset`` and ``scale`` are also applied to the ground-truth
                mesh vertices.
            mesh_data: ground-truth mesh (``vertices`` / ``vertex_indices``).

        Returns:
            The matplotlib ``Figure``. The caller is responsible for saving
            and closing it.
        """
        rotation_matrix, translation_vector = self._camera_transform()

        def to_camera(points):
            # (N, 3) world points -> (N, 3) camera-space points.
            return (rotation_matrix @ points.T).T + translation_vector

        num_rows = len(self.num_points_samples_list)
        dpi = 100
        figsize_inch = np.array([600, 200 * num_rows]) / dpi

        # The ground truth is the same for every row, so render it once.
        gt_vertices = (mesh_data.vertices +
                       surface_data.offset) * surface_data.scale
        gt_mesh = trimesh.Trimesh(vertices=to_camera(gt_vertices),
                                  faces=mesh_data.vertex_indices)
        gt_image = pcl.render_mesh(gt_mesh, camera_mag=1)

        # squeeze=False keeps ``axes`` 2-D even when there is a single row.
        fig, axes = plt.subplots(num_rows, 3, figsize=figsize_inch,
                                 squeeze=False)

        column_titles = ("Input points", "Reconstruction", "Ground truth")
        for row, num_point_samples in enumerate(self.num_points_samples_list):
            print(f"row {row+1} of {num_rows}", flush=True)
            input_ax, recon_ax, gt_ax = axes[row]

            # Re-seed for every row so each row starts from the same RNG state.
            random_state = np.random.RandomState(0)
            minibatch = self.data_generator([surface_data],
                                            num_input_points=num_point_samples,
                                            num_gt_points=0,
                                            random_state=random_state)

            # Column 0: input point cloud, shaded by x coordinate.
            input_points = minibatch.input_points[0].detach().cpu().numpy()
            input_points = input_points.reshape((-1, 3))
            image = pcl.render_point_cloud(to_camera(input_points),
                                           self._cmap_binary(input_points),
                                           camera_mag=1,
                                           point_size=6)
            input_ax.imshow(image)
            input_ax.set_ylabel(num_point_samples)

            # Column 1: reconstruction. Mesh extraction or rendering may raise
            # ValueError. As before, the panel is left blank, but the failure
            # is now reported on stderr instead of being silently dropped.
            try:
                with torch.no_grad():
                    c = self.model.encode_inputs(minibatch.input_points)
                mc_vertices, mc_faces = self.marching_cubes(c)
                mesh = trimesh.Trimesh(vertices=to_camera(mc_vertices),
                                       faces=mc_faces)
                recon_ax.imshow(pcl.render_mesh(mesh, camera_mag=1))
            except ValueError:
                print(f"reconstruction failed for row {row+1} "
                      f"({num_point_samples} input points):",
                      file=sys.stderr, flush=True)
                traceback.print_exc()

            # Column 2: ground truth.
            if gt_image is not None:
                gt_ax.imshow(gt_image)

            for ax in (input_ax, recon_ax, gt_ax):
                ax.set_xticks([])
                ax.set_yticks([])
            if row == 0:
                for ax, title in zip((input_ax, recon_ax, gt_ax),
                                     column_titles):
                    ax.set_title(title, fontsize=10)
        return fig


@click.group()
def client():
    # Root command group: the plot_data / plot_dataset commands in this file
    # register on it via @client.command. The body is intentionally empty.
    ...


@client.command(name="plot_data")
@click.argument("--output-directory", type=str, required=True)
@click.argument("--checkpoint-directory", type=str, required=True)
@click.argument("--checkpoint-epoch", type=int, default=None)
@click.argument("--surface-path", type=str, required=True)
@click.argument("--sdf-path", type=str, required=True)
@click.argument("--obj-path", type=str, required=True)
@click.argument("--grid-size", type=int, default=256)
def plot_data(args):
    """Plot input points / reconstruction / ground truth for a single object.

    Loads the model from ``--checkpoint-directory``. The checkpoint is
    ``model.pt``, or ``model.<epoch>.pt`` when ``--checkpoint-epoch`` is
    given; the hyperparameters come from the ``args.json`` stored next to it.
    One figure is written per dataset item to
    ``<output-directory>/<category_id>_<model_id>.png``. The ids are the 4th-
    and 3rd-to-last components of the mesh path.

    Raises:
        FileNotFoundError: if ``args.json`` or the checkpoint file is missing.
    """
    device = torch.device("cuda", 0)
    surface_path = Path(args.surface_path)
    sdf_path = Path(args.sdf_path)
    obj_path = Path(args.obj_path)
    output_directory = Path(args.output_directory)
    mkdir(output_directory)

    checkpoint_directory = Path(args.checkpoint_directory)
    args_path = checkpoint_directory / "args.json"
    if args.checkpoint_epoch is None:
        model_path = checkpoint_directory / "model.pt"
    else:
        model_path = checkpoint_directory / f"model.{args.checkpoint_epoch}.pt"
    # Explicit checks rather than ``assert``, which is stripped under -O.
    for required_path in (args_path, model_path):
        if not required_path.is_file():
            raise FileNotFoundError(f"required file not found: {required_path}")

    model_hyperparams = ModelHyperparameters.load_json(args_path)
    training_hyperparams = TrainingHyperparameters.load_json(args_path)

    model = setup_model(model_hyperparams)
    load_model(model_path, model)
    model.to(device)
    model.eval()

    dataset = SurfaceSdfMeshPairDataset([(surface_path, sdf_path, obj_path)])
    data_generator = MinibatchGenerator(
        num_input_points=training_hyperparams.num_input_points,
        num_gt_points=training_hyperparams.num_gt_points,
        noise_stddev=0,
        device=device)

    grid_max_value = 1
    grid_min_value = -1
    marching_cubes = MarchingCubes(model=model,
                                   grid_size=args.grid_size,
                                   grid_max_value=grid_max_value,
                                   grid_min_value=grid_min_value)

    num_points_samples_list = [50, 100, 500, 1000, 5000]
    plot = Plot(model=model,
                data_generator=data_generator,
                num_points_samples_list=num_points_samples_list,
                marching_cubes_func=marching_cubes)

    for data_tuple in dataset:
        surface_data = data_tuple[0]
        mesh_data = data_tuple[1]

        fig = plot(surface_data, mesh_data)

        # pathlib splits on the platform's separator; str.split("/") does not.
        parts = Path(mesh_data.path).parts
        category_id = parts[-4]
        model_id = parts[-3]
        figure_path = output_directory / f"{category_id}_{model_id}.png"

        # Operate on ``fig`` explicitly rather than pyplot's implicit current
        # figure, and always release it even if saving fails.
        try:
            fig.tight_layout()
            # Reserves headroom at the top (presumably left over from a former
            # suptitle — confirm it is still wanted).
            fig.subplots_adjust(top=0.92)
            fig.savefig(figure_path, dpi=300, bbox_inches="tight",
                        pad_inches=0.05)
        finally:
            plt.close(fig)
        print(figure_path, flush=True)


@client.command(name="plot_dataset")
@click.argument("--output-directory", type=str, required=True)
@click.argument("--checkpoint-directory", type=str, required=True)
@click.argument("--checkpoint-epoch", type=int, default=None)
@click.argument("--surface-dataset-directory", type=str, required=True)
@click.argument("--sdf-dataset-directory", type=str, required=True)
@click.argument("--obj-dataset-directory", type=str, required=True)
@click.argument("--test-split-path", type=str, required=True)
@click.argument("--grid-size", type=int, default=256)
def plot_dataset(args):
    """Plot input points / reconstruction / ground truth for every test object.

    The test split is a JSON object mapping category ids to lists of model
    ids. For each ``(category_id, model_id)`` the following files are used:

    - ``<surface-dataset-directory>/<category_id>/<model_id>/point_cloud.npz``
    - ``<sdf-dataset-directory>/<category_id>/<model_id>/sdf.npz``
    - ``<obj-dataset-directory>/<category_id>/<model_id>/models/model_normalized.obj``

    The model is loaded from ``--checkpoint-directory`` (``model.pt``, or
    ``model.<epoch>.pt`` when ``--checkpoint-epoch`` is given) with the
    hyperparameters in its ``args.json``. Figures are written to
    ``<output-directory>/<category_id>_<model_id>.png``, iterating over
    ``dataset.shuffle()``.

    Raises:
        FileNotFoundError: if ``args.json``, the checkpoint, or the test split
            file is missing.
    """
    device = torch.device("cuda", 0)
    surface_dataset_directory = Path(args.surface_dataset_directory)
    sdf_dataset_directory = Path(args.sdf_dataset_directory)
    obj_dataset_directory = Path(args.obj_dataset_directory)
    output_directory = Path(args.output_directory)
    mkdir(output_directory)

    checkpoint_directory = Path(args.checkpoint_directory)
    args_path = checkpoint_directory / "args.json"
    if args.checkpoint_epoch is None:
        model_path = checkpoint_directory / "model.pt"
    else:
        model_path = checkpoint_directory / f"model.{args.checkpoint_epoch}.pt"
    test_split_path = Path(args.test_split_path)
    # Explicit checks rather than ``assert``, which is stripped under -O.
    for required_path in (args_path, model_path, test_split_path):
        if not required_path.is_file():
            raise FileNotFoundError(f"required file not found: {required_path}")

    model_hyperparams = ModelHyperparameters.load_json(args_path)
    training_hyperparams = TrainingHyperparameters.load_json(args_path)

    model = setup_model(model_hyperparams)
    load_model(model_path, model)
    model.to(device)
    model.eval()

    with open(test_split_path, encoding="utf-8") as f:
        split = json.load(f)

    # Every model listed in the split is passed on as-is; file existence is
    # not checked here.
    surface_sdf_obj_path_list = []
    for category_id, model_id_list in split.items():
        for model_id in model_id_list:
            surface_path = surface_dataset_directory / category_id / model_id / "point_cloud.npz"
            sdf_path = sdf_dataset_directory / category_id / model_id / "sdf.npz"
            obj_path = obj_dataset_directory / category_id / model_id / "models" / "model_normalized.obj"
            surface_sdf_obj_path_list.append(
                (surface_path, sdf_path, obj_path))

    dataset = SurfaceSdfMeshPairDataset(surface_sdf_obj_path_list)

    data_generator = MinibatchGenerator(
        num_input_points=training_hyperparams.num_input_points,
        num_gt_points=training_hyperparams.num_gt_points,
        noise_stddev=0,
        device=device)

    grid_max_value = 1
    grid_min_value = -1
    marching_cubes = MarchingCubes(model=model,
                                   grid_size=args.grid_size,
                                   grid_max_value=grid_max_value,
                                   grid_min_value=grid_min_value)

    num_points_samples_list = [50, 100, 500, 1000, 5000]
    plot = Plot(model=model,
                data_generator=data_generator,
                num_points_samples_list=num_points_samples_list,
                marching_cubes_func=marching_cubes)

    for data_tuple in dataset.shuffle():
        surface_data = data_tuple[0]
        mesh_data = data_tuple[1]

        fig = plot(surface_data, mesh_data)

        # For .../<category_id>/<model_id>/models/model_normalized.obj the ids
        # are the 4th- and 3rd-to-last components. pathlib splits on the
        # platform's separator; str.split("/") does not.
        parts = Path(mesh_data.path).parts
        category_id = parts[-4]
        model_id = parts[-3]
        figure_path = output_directory / f"{category_id}_{model_id}.png"

        # Operate on ``fig`` explicitly rather than pyplot's implicit current
        # figure, and always release it so a long loop does not leak figures.
        try:
            fig.tight_layout()
            # Reserves headroom at the top (presumably left over from a former
            # suptitle — confirm it is still wanted).
            fig.subplots_adjust(top=0.92)
            fig.savefig(figure_path, dpi=300, bbox_inches="tight",
                        pad_inches=0.05)
        finally:
            plt.close(fig)
        print(figure_path, flush=True)


if __name__ == "__main__":
    # Script entry point: run the ``client`` command group, which presumably
    # dispatches to the sub-command named on the command line.
    client()
