import json
import math
from pathlib import Path
from typing import Callable, List

import colorful
import matplotlib.pyplot as plt
import numpy as np
import torch
import trimesh

from pcn import click, load_model, mkdir
from pcn.datasets.partial_sampling import (MeshDataDescription,
                                           MinibatchGenerator,
                                           MinibatchDescription,
                                           GtPartialSamplingData,
                                           PointCloudAndMeshPairDataset)
from pcn.experiment import (Model, ModelHyperparameters,
                            TrainingHyperparameters, setup_model)
from pcn.point_cloud import render_mesh, render_point_cloud


def _normalize(vec: np.ndarray) -> np.ndarray:
    """Return ``vec`` scaled to unit Euclidean length.

    Raises:
        ValueError: if ``vec`` has zero length. A plain division would
            silently produce NaNs instead (e.g. in ``_look_at`` when the eye
            coincides with the target, or when ``up`` is parallel to the
            viewing direction so their cross product vanishes).
    """
    norm = np.linalg.norm(vec)
    if norm == 0:
        raise ValueError("cannot normalize a zero-length vector")
    return vec / norm


def _look_at(eye: np.ndarray, center: np.ndarray, up: np.ndarray):
    """Build a look-at camera basis and its matching translation.

    Returns ``(rotation_matrix, translation_vector)``:

    * ``rotation_matrix`` -- float32 3x3 array whose *columns* are the
      camera's x axis (``up x z``), y axis (``z x x``) and z axis (pointing
      from ``center`` towards ``eye``), each of unit length.
    * ``translation_vector`` -- length-3 array holding ``-axis @ eye`` for
      each of those three axes, in the same order.
    """
    eye_point = np.asanyarray(eye)
    target = np.asanyarray(center)
    up_hint = np.asanyarray(up)

    z_axis = _normalize(eye_point - target)
    x_unscaled = np.cross(up_hint, z_axis)
    # y is derived from the *unnormalized* x; its magnitude is discarded by
    # the normalization anyway.
    y_axis = _normalize(np.cross(z_axis, x_unscaled))
    x_axis = _normalize(x_unscaled)

    basis = (x_axis, y_axis, z_axis)
    rotation_matrix = np.stack(basis, axis=1).astype(np.float32)
    translation_vector = np.array([(-axis) @ eye_point for axis in basis])
    return rotation_matrix, translation_vector


class Plot:
    """Render a comparison grid of model predictions and the ground truth.

    The figure has one row per entry of ``num_point_samples_list``, which is
    the number of input points passed to the minibatch generator for that
    row. The four columns show the input points, the coarse model output,
    the dense model output and the ground-truth mesh. All four are rendered
    after the same fixed camera transform.
    """

    def __init__(self, model: Model, minibatch_generator: MinibatchGenerator,
                 num_point_samples_list: List[int]):
        self.model = model
        self.minibatch_generator = minibatch_generator
        self.num_point_samples_list = num_point_samples_list

    @staticmethod
    def _camera_transform():
        """Return ``(rotation_matrix, translation_vector)`` for the fixed view.

        ``rotation_matrix`` is 3x3. ``translation_vector`` has shape
        ``(1, 3)`` so that it broadcasts over an ``(N, 3)`` point array.
        """
        camera_theta = math.pi / 3
        camera_phi = -math.pi / 4
        camera_r = 1
        eye = [
            camera_r * math.sin(camera_theta) * math.cos(camera_phi),
            camera_r * math.cos(camera_theta),
            camera_r * math.sin(camera_theta) * math.sin(camera_phi),
        ]
        rotation_matrix, translation_vector = _look_at(eye=eye,
                                                       center=[0, 0, 0],
                                                       up=[0, 1, 0])
        # _look_at returns the camera axes as columns. Invert so that
        # ``rotation_matrix @ p`` projects p onto those axes.
        return np.linalg.inv(rotation_matrix), translation_vector[None, :]

    @staticmethod
    def _cmap_binary(points: np.ndarray) -> np.ndarray:
        """Map each point's x coordinate to an RGB grey level in [0, 0.3].

        The x coordinates are scaled to [-1, 1] by their max absolute value
        and negated, so points with the largest positive x are darkest.
        ``points`` is never modified.
        """
        x = points[:, 0]
        max_abs = np.max(np.abs(x))
        # Avoid 0/0 -> NaN when every point lies on the x = 0 plane.
        scale = 1 / max_abs if max_abs > 0 else 1.0
        # Out-of-place multiply: ``x`` is a view into ``points``, which may
        # in turn share memory with a CPU torch tensor.
        x = x * -scale
        intensity = 0.3 * (x + 1) / 2
        return np.repeat(intensity[:, None], 3, axis=1)

    @staticmethod
    def _format_axis(ax, row: int, title: str, ylabel=None):
        """Hide ticks, optionally set the y label, and title the first row."""
        ax.set_xticks([])
        ax.set_yticks([])
        if ylabel is not None:
            ax.set_ylabel(ylabel)
        if row == 0:
            ax.set_title(title, fontsize=10)

    def _draw_point_cloud(self, ax, batch_points, rotation_matrix,
                          translation_vector):
        """Render the first element of a batched point tensor into ``ax``."""
        points = batch_points[0].detach().cpu().numpy().reshape((-1, 3))
        camera_points = (rotation_matrix @ points.T).T + translation_vector
        # Colours come from the coordinates *before* the camera transform.
        colors = self._cmap_binary(points)
        image = render_point_cloud(camera_points,
                                   colors,
                                   camera_mag=1,
                                   point_size=6)
        ax.imshow(image)

    def __call__(self, pc_data: GtPartialSamplingData,
                 mesh_data: MeshDataDescription):
        """Build and return the comparison figure for a single object.

        The ground-truth mesh vertices are mapped through
        ``(v + pc_data.offset) * pc_data.scale`` and then the camera
        transform before rendering. The point clouds get only the camera
        transform.
        """
        rotation_matrix, translation_vector = self._camera_transform()
        num_rows = len(self.num_point_samples_list)

        figsize_px = np.array([800, 200 * num_rows])
        dpi = 100
        figsize_inch = figsize_px / dpi

        # Ground truth is identical for every row, so render it once.
        gt_vertices = (mesh_data.vertices + pc_data.offset) * pc_data.scale
        gt_vertices = (rotation_matrix @ gt_vertices.T).T + translation_vector
        gt_mesh = trimesh.Trimesh(vertices=gt_vertices,
                                  faces=mesh_data.vertex_indices)
        gt_image = render_mesh(gt_mesh, camera_mag=1)

        # squeeze=False keeps ``axes`` 2-D even for a single row.
        fig, axes = plt.subplots(num_rows,
                                 4,
                                 figsize=figsize_inch,
                                 squeeze=False)

        for row, num_point_samples in enumerate(self.num_point_samples_list):
            print(f"row {row+1} of {num_rows}", flush=True)
            # Fresh, fixed-seed state per row so every row is reproducible.
            random_state = np.random.RandomState(0)
            data = self.minibatch_generator([pc_data],
                                            num_input_points=num_point_samples,
                                            random_state=random_state)

            # Inference only: do not build an autograd graph.
            with torch.no_grad():
                pred_coarse_points, pred_dense_points = self.model(
                    data.input_points)

            panels = (
                (data.input_points, "Input points"),
                (pred_coarse_points, "Coarse output"),
                (pred_dense_points, "Dense output"),
            )
            for column, (batch_points, title) in enumerate(panels):
                ax = axes[row][column]
                self._draw_point_cloud(ax, batch_points, rotation_matrix,
                                       translation_vector)
                self._format_axis(ax, row, title, ylabel=num_point_samples)

            gt_ax = axes[row][3]
            gt_ax.imshow(gt_image)
            self._format_axis(gt_ax, row, "Ground truth")
        return fig


@click.group()
def client():
    # Root command group. Sub-commands (e.g. ``plot_dataset``) register
    # themselves on it via ``@client.command``; the group itself does nothing.
    ...


@client.command(name="plot_dataset")
@click.argument("--checkpoint-directory", type=str, required=True)
@click.argument("--checkpoint-epoch", type=int, default=None)
@click.argument("--output-directory", type=str, required=True)
@click.argument("--npz-dataset-directory", type=str, required=True)
@click.argument("--obj-dataset-directory", type=str, required=True)
@click.argument("--test-split-path", type=str, required=True)
@click.argument("--seed", type=int, default=0)
def plot_dataset(args):
    """Write one comparison figure per object in the test split.

    Loads ``args.json`` and ``model.pt`` (or ``model.<epoch>.pt`` when
    ``--checkpoint-epoch`` is given) from ``--checkpoint-directory``. It then
    pairs each split entry's ``partial_point_cloud.npz`` with its
    ``models/model_normalized.obj``, skipping entries where either file is
    missing. Finally it saves ``<category_id>_<model_id>.png`` into
    ``--output-directory``.

    Raises:
        FileNotFoundError: if the checkpoint files or the split file are
            missing.
    """
    # --seed was previously parsed but never applied. Seed the global
    # NumPy/PyTorch RNGs so any sampling relying on them is reproducible.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    device = torch.device("cuda", 0)
    npz_dataset_directory = Path(args.npz_dataset_directory)
    obj_dataset_directory = Path(args.obj_dataset_directory)
    output_directory = Path(args.output_directory)
    mkdir(output_directory)

    checkpoint_directory = Path(args.checkpoint_directory)
    args_path = checkpoint_directory / "args.json"
    if args.checkpoint_epoch is None:
        model_path = checkpoint_directory / "model.pt"
    else:
        model_path = checkpoint_directory / f"model.{args.checkpoint_epoch}.pt"
    # Explicit checks rather than ``assert``, which is stripped under -O.
    for required_path in (args_path, model_path):
        if not required_path.is_file():
            raise FileNotFoundError(f"file not found: {required_path}")

    model_hyperparams = ModelHyperparameters.load_json(args_path)
    training_hyperparams = TrainingHyperparameters.load_json(args_path)

    model = setup_model(model_hyperparams)
    load_model(model_path, model)
    model.to(device)
    model.eval()

    test_split_path = Path(args.test_split_path)
    if not test_split_path.is_file():
        raise FileNotFoundError(f"file not found: {test_split_path}")

    with open(test_split_path, encoding="utf-8") as f:
        split = json.load(f)

    # Keep only (npz, obj) pairs where both files exist on disk.
    npz_obj_path_list = []
    for category_id, model_ids in split.items():
        for model_id in sorted(model_ids):
            npz_path = (npz_dataset_directory / category_id / model_id /
                        "partial_point_cloud.npz")
            obj_path = (obj_dataset_directory / category_id / model_id /
                        "models" / "model_normalized.obj")
            if npz_path.exists() and obj_path.exists():
                npz_obj_path_list.append((npz_path, obj_path))

    print(len(npz_obj_path_list))

    dataset = PointCloudAndMeshPairDataset(
        npz_obj_path_list,
        num_coarse_points=model_hyperparams.num_coarse_gt_points,
        num_dense_points=model_hyperparams.num_dense_gt_points)
    minibatch_generator = MinibatchGenerator(
        num_input_points=training_hyperparams.num_input_points, device=device)

    num_point_samples_list = [50, 100, 500, 1000, 5000]
    plot = Plot(model=model,
                minibatch_generator=minibatch_generator,
                num_point_samples_list=num_point_samples_list)

    for data_tuple in dataset:
        pc_data: GtPartialSamplingData = data_tuple[0]
        mesh_data: MeshDataDescription = data_tuple[1]

        fig = plot(pc_data, mesh_data)

        # The mesh path was built above as
        # <obj_dir>/<category_id>/<model_id>/models/model_normalized.obj;
        # Path.parts avoids hard-coding the "/" separator.
        parts = Path(mesh_data.path).parts
        category_id, model_id = parts[-4], parts[-3]

        figure_path = output_directory / f"{category_id}_{model_id}.png"
        # Operate on ``fig`` explicitly instead of pyplot's implicit
        # "current figure".
        fig.tight_layout()
        fig.subplots_adjust(top=0.92)
        fig.savefig(figure_path, dpi=300, bbox_inches="tight", pad_inches=0.05)
        plt.close(fig)
        print(figure_path, flush=True)


if __name__ == "__main__":
    # Script entry point: hand control to the command group, which dispatches
    # to the selected sub-command (e.g. ``plot_dataset``).
    client()
