import json
import math
from pathlib import Path
from typing import Callable, List

import random
import matplotlib.pyplot as plt
import numpy as np
import torch
import trimesh

from pcn import click, load_model, mkdir
from pcn.datasets.uniform_sparse_sampling import (MeshDataDescription,
                                                  MinibatchDescription,
                                                  MinibatchGenerator, Dataset,
                                                  GtUniformPointCloudData)
from pcn.experiment import (Model, ModelHyperparameters,
                            TrainingHyperparameters, setup_model)
from pcn.point_cloud import render_point_cloud


def _normalize(vec: np.ndarray):
    """Scale ``vec`` to unit Euclidean (L2) length.

    There is no guard for a zero-length input: dividing by a zero norm
    yields non-finite values instead of raising.
    """
    magnitude = np.linalg.norm(vec)
    return vec / magnitude


def _look_at(eye: np.ndarray, center: np.ndarray, up: np.ndarray):
    """Build the orthonormal frame of a camera at ``eye`` looking at ``center``.

    Args:
        eye: camera position (array-like of 3 values).
        center: point the camera looks at.
        up: approximate up direction used to orient the frame.

    Returns:
        ``(rotation_matrix, translation_vector)``. ``rotation_matrix`` is a
        float32 3x3 matrix whose columns are, in order, the right axis, the
        up axis and the backward axis (the unit vector from ``center`` to
        ``eye``). ``translation_vector`` holds ``-axis @ eye`` for each of
        those three axes.
    """
    def unit(v):
        return v / np.linalg.norm(v)

    eye, center, up = (np.asanyarray(a) for a in (eye, center, up))

    backward = unit(eye - center)
    right_raw = np.cross(up, backward)
    # The up axis is derived from the *unnormalised* right vector; both are
    # normalised only afterwards.
    true_up_raw = np.cross(backward, right_raw)
    right = unit(right_raw)
    true_up = unit(true_up_raw)

    frame_axes = (right, true_up, backward)
    rotation_matrix = np.column_stack(frame_axes).astype(np.float32)
    translation_vector = np.array([-axis @ eye for axis in frame_axes])
    return rotation_matrix, translation_vector


class Plot:
    """Render a comparison grid of model inputs, predictions and ground truth.

    Calling an instance with one point-cloud sample returns a matplotlib
    figure with one row per entry of ``num_point_samples_list`` and four
    columns: the sampled input points, the coarse prediction, the dense
    prediction and the ground-truth dense points. All four are drawn from the
    same fixed camera.
    """

    # Column headers, shown on the first row only.
    _COLUMN_TITLES = ("Input points", "Coarse output", "Dense output",
                      "Ground truth")

    def __init__(self, model: Model, minibatch_generator: MinibatchGenerator,
                 num_point_samples_list: List[int]):
        """
        Args:
            model: network called as ``model(input_points)``; its output is
                unpacked into ``(coarse_points, dense_points)``.
            minibatch_generator: builds the model input from a list of
                point-cloud samples and a requested number of input points.
            num_point_samples_list: number of input points to sample for each
                figure row.
        """
        self.model = model
        self.minibatch_generator = minibatch_generator
        self.num_point_samples_list = num_point_samples_list

    @staticmethod
    def _camera_transform():
        """Return ``(rotation, translation)`` mapping world points to camera space.

        The camera sits at distance 1 from the origin (polar angle pi/3
        measured from +y, azimuth -pi/4 in the x-z plane), looks at the
        origin, and uses +y as up. ``translation`` has shape (1, 3) so it
        broadcasts over (N, 3) point arrays.
        """
        camera_theta = math.pi / 3
        camera_phi = -math.pi / 4
        camera_r = 1
        eye = [
            camera_r * math.sin(camera_theta) * math.cos(camera_phi),
            camera_r * math.cos(camera_theta),
            camera_r * math.sin(camera_theta) * math.sin(camera_phi),
        ]
        rotation_matrix, translation_vector = _look_at(eye=eye,
                                                       center=[0, 0, 0],
                                                       up=[0, 1, 0])
        # _look_at returns the camera axes as matrix columns; the inverse is
        # the rotation taking world coordinates into that frame.
        return np.linalg.inv(rotation_matrix), translation_vector[None, :]

    @staticmethod
    def _cmap_binary(points: np.ndarray) -> np.ndarray:
        """Grey-scale RGB colours derived from each point's x coordinate.

        x is normalised by its largest magnitude and mapped linearly to an
        intensity in [0, 0.3]: the largest x is darkest. If every x is 0 all
        points get the mid intensity 0.15. Empty input yields a (0, 3) array.
        The input array is never modified.
        """
        x = np.asarray(points, dtype=np.float64)[:, 0]
        max_abs = np.abs(x).max(initial=0.0)
        # Guard against division by zero when every x is 0 (or no points).
        normalized = -x / max_abs if max_abs > 0 else np.zeros_like(x)
        intensity = 0.3 * (normalized + 1) / 2
        return np.repeat(intensity[:, None], 3, axis=1)

    @staticmethod
    def _first_sample(tensor: torch.Tensor) -> np.ndarray:
        """Return batch element 0 of ``tensor`` as an (N, 3) NumPy array."""
        return tensor[0].detach().cpu().numpy().reshape((-1, 3))

    @staticmethod
    def _show(ax, image, title=None):
        """Draw ``image`` on ``ax`` without ticks, optionally with a title."""
        ax.imshow(image)
        ax.set_xticks([])
        ax.set_yticks([])
        if title is not None:
            ax.set_title(title, fontsize=10)

    def _render(self, points, rotation_matrix, translation_vector):
        """Transform ``points`` into camera space and rasterise them."""
        points = np.asarray(points).reshape((-1, 3))
        camera_points = (rotation_matrix @ points.T).T + translation_vector
        return render_point_cloud(camera_points,
                                  self._cmap_binary(points),
                                  camera_mag=1,
                                  point_size=6)

    def __call__(self, pc_data: GtUniformPointCloudData):
        """Build and return the comparison figure for a single sample."""
        rotation_matrix, translation_vector = self._camera_transform()

        num_rows = len(self.num_point_samples_list)
        dpi = 100
        figsize_inch = np.array([800, 200 * num_rows]) / dpi
        # squeeze=False always yields a 2-D (num_rows, 4) array of axes.
        fig, axes = plt.subplots(num_rows,
                                 4,
                                 figsize=figsize_inch,
                                 squeeze=False)

        # The ground truth is the same for every row: render it once.
        gt_image = self._render(pc_data.gt_dense_points, rotation_matrix,
                                translation_vector)

        for row, num_point_samples in enumerate(self.num_point_samples_list):
            print(f"row {row+1} of {num_rows}", flush=True)
            # Re-seeded per row so each row's input sampling is reproducible.
            random_state = np.random.RandomState(0)
            data = self.minibatch_generator([pc_data],
                                            num_input_points=num_point_samples,
                                            random_state=random_state)

            # Inference only: skip building an autograd graph.
            with torch.no_grad():
                pred_coarse_points, pred_dense_points = self.model(
                    data.input_points)

            images = (
                self._render(self._first_sample(data.input_points),
                             rotation_matrix, translation_vector),
                self._render(self._first_sample(pred_coarse_points),
                             rotation_matrix, translation_vector),
                self._render(self._first_sample(pred_dense_points),
                             rotation_matrix, translation_vector),
                gt_image,
            )
            for ax, image, title in zip(axes[row], images,
                                        self._COLUMN_TITLES):
                self._show(ax, image, title if row == 0 else None)
            axes[row, 0].set_ylabel(num_point_samples)
        return fig


@click.group()
def client():
    # Command group only; sub-commands (e.g. ``plot_dataset``) attach to it
    # through ``@client.command``.
    ...


def _collect_npz_paths(dataset_directory: Path, split: dict) -> List[Path]:
    """List ``<dataset>/<category>/depth/<model>.npz`` for every split entry.

    ``split`` is the parsed test-split JSON: a mapping from category id to an
    iterable of model ids. Model ids are visited in sorted order within each
    category.
    """
    return [
        dataset_directory / category_id / "depth" / f"{model_id}.npz"
        for category_id, model_ids in split.items()
        for model_id in sorted(model_ids)
    ]


def _figure_path(output_directory: Path, npz_path) -> Path:
    """Map ``.../<category>/depth/<model>.npz`` to ``<output>/<category>_<model>.png``."""
    npz_path = Path(npz_path)
    category_id = npz_path.parts[-3]
    model_id = npz_path.stem
    return output_directory / f"{category_id}_{model_id}.png"


@client.command(name="plot_dataset")
@click.argument("--checkpoint-directory", type=str, required=True)
@click.argument("--checkpoint-epoch", type=int, default=None)
@click.argument("--output-directory", type=str, required=True)
@click.argument("--dataset-directory", type=str, required=True)
@click.argument("--test-split-path", type=str, required=True)
@click.argument("--seed", type=int, default=0)
def plot_dataset(args):
    """Render one comparison figure per model listed in a test split.

    Loads hyperparameters from ``<checkpoint>/args.json`` and weights from
    ``model.pt`` (or ``model.<epoch>.pt``). Builds the ``.npz`` paths listed
    in the split file and writes ``<category>_<model>.png`` into the output
    directory. Models whose figure already exists are skipped.

    Raises:
        FileNotFoundError: if ``args.json``, the checkpoint or the split file
            is missing.
    """
    checkpoint_directory = Path(args.checkpoint_directory)
    args_path = checkpoint_directory / "args.json"
    if args.checkpoint_epoch is None:
        model_path = checkpoint_directory / "model.pt"
    else:
        model_path = checkpoint_directory / f"model.{args.checkpoint_epoch}.pt"
    test_split_path = Path(args.test_split_path)
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    for required_path in (args_path, model_path, test_split_path):
        if not required_path.is_file():
            raise FileNotFoundError(f"required file not found: {required_path}")

    dataset_directory = Path(args.dataset_directory)
    output_directory = Path(args.output_directory)
    mkdir(output_directory)

    with open(test_split_path, encoding="utf-8") as f:
        split = json.load(f)
    npz_path_list = _collect_npz_paths(dataset_directory, split)

    # NOTE(review): `--seed` is not used and this shuffle is unseeded. Combined
    # with the per-figure exists() check below, a random order lets several
    # runs share one output directory with little overlap. Use
    # `random.Random(args.seed).shuffle(...)` if a reproducible order is wanted.
    random.shuffle(npz_path_list)

    # Build the dataset only from models that still need a figure.
    pending_paths = [
        npz_path for npz_path in npz_path_list
        if not _figure_path(output_directory, npz_path).exists()
    ]
    print(f"{len(pending_paths)} of {len(npz_path_list)} models left to plot",
          flush=True)
    if not pending_paths:
        return

    device = (torch.device("cuda", 0)
              if torch.cuda.is_available() else torch.device("cpu"))

    model_hyperparams = ModelHyperparameters.load_json(args_path)
    training_hyperparams = TrainingHyperparameters.load_json(args_path)

    model = setup_model(model_hyperparams)
    load_model(model_path, model)
    model.to(device)
    model.eval()

    dataset = Dataset(pending_paths,
                      num_coarse_points=model_hyperparams.num_coarse_gt_points,
                      num_dense_points=model_hyperparams.num_dense_gt_points)
    minibatch_generator = MinibatchGenerator(
        num_input_points=training_hyperparams.num_input_points, device=device)

    # One figure row per entry; e.g. [50, 100, 300, 1000] compares several
    # input densities side by side.
    num_point_samples_list = [100]
    plot = Plot(model=model,
                minibatch_generator=minibatch_generator,
                num_point_samples_list=num_point_samples_list)

    for pc_data in dataset:
        figure_path = _figure_path(output_directory, pc_data.path)
        # Re-check: the file may have appeared since the filter above
        # (e.g. written by a concurrent run).
        if figure_path.exists():
            continue

        fig = plot(pc_data)
        fig.tight_layout()
        # Leaves headroom above the grid (e.g. for a suptitle).
        fig.subplots_adjust(top=0.92)
        fig.savefig(figure_path, dpi=300, bbox_inches="tight", pad_inches=0.05)
        plt.close(fig)
        print(figure_path, flush=True)


if __name__ == "__main__":
    # Script entry point: hand control to the `client` command group.
    client()
