# https://github.com/pytorch/pytorch/issues/19739
import json
import math
import random
from pathlib import Path
from typing import Callable, List

import matplotlib.pyplot as plt
import numpy as np
import open3d
import torch
import torch.optim as optim
import trimesh
from skimage import measure

from implicit_geometric_regularization import click, load_model, mkdir
from implicit_geometric_regularization.datasets.shapenet.uniform_sparse_sampling import (
    CombinedDataset, MeshData, MinibatchGenerator, UniformPointCloudData)
from implicit_geometric_regularization.experiments.learning_shape_space import (
    Decoder, DecoderHyperparameters, LatentOptimization, LossFunction,
    MarchingCubes, Model, ModelHyperparameters, TrainingHyperparameters,
    LatentOptimizationHyperparameters, setup_model)
from implicit_geometric_regularization.datasets.functions import (
    read_mesh_data, read_uniform_point_cloud_data)
from implicit_geometric_regularization.point_cloud import render_mesh, render_point_cloud


def _normalize(vec: np.ndarray):
    """Return ``vec`` divided by its Euclidean length.

    A zero-length input yields NaNs (NumPy division by zero); callers are
    expected to pass non-degenerate vectors.
    """
    length = np.linalg.norm(vec)
    return vec / length


def _look_at(eye: np.ndarray, center: np.ndarray, up: np.ndarray):
    """Build a look-at camera frame.

    Args:
        eye: camera position (array-like of length 3).
        center: point the camera looks at (array-like of length 3).
        up: approximate world "up" direction (array-like of length 3).

    Returns:
        ``(rotation_matrix, translation_vector)`` where ``rotation_matrix``
        (float32, 3x3) has the camera's x, y, z axes as its *columns*, with z
        pointing from ``center`` towards ``eye``; and ``translation_vector``
        is ``-(x.eye, y.eye, z.eye)``, i.e. the translation of the
        world-to-camera transform whose rotation is ``rotation_matrix.T``.

    Raises:
        ValueError: if ``eye`` equals ``center`` or ``up`` is parallel to
            the viewing direction; either would otherwise produce NaNs.
    """
    eye = np.asanyarray(eye)
    center = np.asanyarray(center)
    up = np.asanyarray(up)

    forward = eye - center
    forward_norm = np.linalg.norm(forward)
    if forward_norm == 0:
        raise ValueError("eye and center must be distinct points")
    z = forward / forward_norm

    x = np.cross(up, z)
    x_norm = np.linalg.norm(x)
    # |up x z| == |up| * sin(angle); ~0 means up is (anti)parallel to z and
    # the horizontal axis is undefined.
    if np.isclose(x_norm, 0.0):
        raise ValueError("up must not be parallel to the viewing direction")
    y = np.cross(z, x)

    x = x / x_norm
    y = y / np.linalg.norm(y)

    rotation_matrix = np.array(
        [
            [x[0], y[0], z[0]],
            [x[1], y[1], z[1]],
            [x[2], y[2], z[2]],
        ],
        dtype=np.float32,
    )
    translation_vector = np.array([-x @ eye, -y @ eye, -z @ eye])

    return rotation_matrix, translation_vector


class Plot:
    """Render a figure comparing inputs, reconstructions and ground truth.

    The figure has one row per entry of ``num_point_samples_list`` and three
    columns: the input point cloud subsampled to that many points, the
    surface reconstructed from a latent code optimized against those points,
    and the ground-truth mesh (the same image in every row).
    """

    def __init__(self, minibatch_generator: MinibatchGenerator,
                 num_point_samples_list: List[int],
                 initial_z_func: Callable[[], torch.Tensor],
                 optimize_latent_func: Callable[..., torch.Tensor],
                 marching_cubes_func: Callable[[torch.Tensor], tuple]):
        # Builds a minibatch from [pc_data] with a given number of points.
        self.minibatch_generator = minibatch_generator
        # Number of input points to use for each figure row.
        self.num_point_samples_list = num_point_samples_list
        # Returns a fresh initial latent code.
        self.get_initial_z = initial_z_func
        # Called as optimize_latent(data, initial_z) -> optimized latent z.
        self.optimize_latent = optimize_latent_func
        # Called as marching_cubes(z) -> (vertices, faces); may raise
        # ValueError, which is treated as "no reconstruction".
        self.marching_cubes = marching_cubes_func

    @staticmethod
    def _cmap_binary(points: np.ndarray) -> np.ndarray:
        """Return grey RGB colors shaded by each point's x coordinate.

        x is normalized by its largest absolute value and negated, so the
        intensity lies in [0, 0.3] and decreases as x grows. ``points`` is
        not modified.
        """
        x = points[:, 0]
        max_abs = np.max(np.abs(x))
        # Guard against 0/0 when every point lies on the x = 0 plane.
        scale = 1 / max_abs if max_abs > 0 else 0.0
        # Out-of-place on purpose: points[:, 0] is a view into the caller's
        # array, so an in-place update would corrupt it.
        x = -x * scale
        intensity = 0.3 * (x + 1) / 2
        return np.repeat(intensity[:, None], 3, axis=1)

    def __call__(self, pc_data: UniformPointCloudData, mesh_data: MeshData):
        """Build the comparison figure for one shape.

        Args:
            pc_data: point cloud of the shape; its ``offset`` and ``scale``
                are applied to the ground-truth mesh vertices.
            mesh_data: ground-truth mesh (``vertices``, ``vertex_indices``).

        Returns:
            The matplotlib figure. The caller is responsible for saving and
            closing it.
        """
        # Fixed camera on a unit sphere looking at the origin, y axis up.
        camera_theta = math.pi / 3
        camera_phi = -math.pi / 4
        camera_r = 1
        eye = [
            camera_r * math.sin(camera_theta) * math.cos(camera_phi),
            camera_r * math.cos(camera_theta),
            camera_r * math.sin(camera_theta) * math.sin(camera_phi),
        ]
        rotation_matrix, translation_vector = _look_at(eye=eye,
                                                       center=[0, 0, 0],
                                                       up=[0, 1, 0])
        translation_vector = translation_vector[None, :]
        # _look_at returns the camera axes as columns; invert to obtain the
        # world-to-camera rotation.
        rotation_matrix = np.linalg.inv(rotation_matrix)

        def to_camera(points: np.ndarray) -> np.ndarray:
            # Apply the world-to-camera transform to an (N, 3) point array.
            return (rotation_matrix @ points.T).T + translation_vector

        num_rows = len(self.num_point_samples_list)
        figsize_px = np.array([600, 200 * num_rows])
        dpi = 100
        figsize_inch = figsize_px / dpi

        # The ground truth is identical for every row, so render it once.
        gt_vertices = (mesh_data.vertices + pc_data.offset) * pc_data.scale
        gt_mesh = trimesh.Trimesh(vertices=to_camera(gt_vertices),
                                  faces=mesh_data.vertex_indices)
        gt_image = render_mesh(gt_mesh, camera_mag=1)

        # squeeze=False keeps `axes` 2-D even when there is a single row.
        fig, axes = plt.subplots(num_rows,
                                 3,
                                 figsize=figsize_inch,
                                 squeeze=False)

        for row, num_point_samples in enumerate(self.num_point_samples_list):
            input_ax, recon_ax, gt_ax = axes[row]

            # Reset the RNG per row so the subsampling is reproducible.
            random_state = np.random.RandomState(0)
            data = self.minibatch_generator(
                [pc_data],
                num_point_samples=num_point_samples,
                random_state=random_state)
            initial_z = self.get_initial_z()
            z = self.optimize_latent(data, initial_z)
            print("optimize_latent done")

            # Input point cloud.
            input_points = data.points[0].detach().cpu().numpy()
            input_points = input_points.reshape((-1, 3))
            image = render_point_cloud(to_camera(input_points),
                                       self._cmap_binary(input_points),
                                       camera_mag=1,
                                       point_size=6)
            input_ax.imshow(image)
            input_ax.set_ylabel(num_point_samples)

            # Reconstruction: best effort, the panel stays blank on failure.
            try:
                mc_vertices, mc_faces = self.marching_cubes(z)
                print("marching_cubes done")
                mesh = trimesh.Trimesh(vertices=to_camera(mc_vertices),
                                       faces=mc_faces)
                recon_ax.imshow(render_mesh(mesh, camera_mag=1))
            except ValueError as error:
                # Presumably raised when the predicted level set does not
                # intersect the sampling grid; report it instead of hiding it.
                print(f"reconstruction skipped for {num_point_samples} "
                      f"points: {error}",
                      flush=True)

            gt_ax.imshow(gt_image)

            for ax in (input_ax, recon_ax, gt_ax):
                ax.set_xticks([])
                ax.set_yticks([])

        for ax, title in zip(axes[0],
                             ("Input points", "Reconstruction",
                              "Ground truth")):
            ax.set_title(title, fontsize=10)
        return fig


@click.group()
def client():
    # Root command group for this script; the sub-commands registered on it
    # below are "plot_data" and "plot_dataset". The group itself does nothing.
    pass


@client.command(name="plot_data")
@click.argument("--checkpoint-directory", type=str, required=True)
@click.argument("--checkpoint-epoch", type=int, default=None)
@click.argument("--output-directory", type=str, required=True)
@click.argument("--pc-data-path", type=str, required=True)
@click.argument("--mesh-data-path", type=str, required=True)
@click.argument("--grid-size", type=int, default=256)
@click.argument("--seed", type=int, default=0)
@click.hyperparameter_class(LatentOptimizationHyperparameters)
def plot_data(args,
              latent_optimization_params: LatentOptimizationHyperparameters):
    """Plot reconstructions of a single shape for several input sizes.

    Loads ``args.json`` and ``model.pt`` (or ``model.<epoch>.pt``) from
    ``--checkpoint-directory``, optimizes a latent code for the point cloud
    at ``--pc-data-path`` with 50/100/300/1000 input points, and writes
    ``<category_id>_<model_id>.png`` to ``--output-directory``.

    Raises:
        FileNotFoundError: if the checkpoint's args or model file is missing.
    """
    device = torch.device("cuda", 0)
    # Seed every RNG so runs are reproducible; the initial latent codes are
    # drawn from torch's global generator.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    output_directory = Path(args.output_directory)
    mkdir(output_directory)

    checkpoint_directory = Path(args.checkpoint_directory)
    args_path = checkpoint_directory / "args.json"
    if args.checkpoint_epoch is None:
        model_path = checkpoint_directory / "model.pt"
    else:
        model_path = checkpoint_directory / f"model.{args.checkpoint_epoch}.pt"
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    for required_path in (args_path, model_path):
        if not required_path.is_file():
            raise FileNotFoundError(required_path)

    model_hyperparams = ModelHyperparameters.load_json(args_path)
    decoder_hyperparams = DecoderHyperparameters.load_json(args_path)
    training_hyperparams = TrainingHyperparameters.load_json(args_path)

    model = setup_model(model_hyperparams, decoder_hyperparams)
    load_model(model_path, model)
    model.to(device)
    model.eval()

    dataset = CombinedDataset([(args.pc_data_path, args.mesh_data_path)],
                              [read_uniform_point_cloud_data, read_mesh_data])
    minibatch_generator = MinibatchGenerator(
        with_normal=training_hyperparams.with_normal, device=device)

    loss_function = LossFunction(
        tau=training_hyperparams.loss_tau,
        lam=training_hyperparams.loss_lambda,
        alpha=training_hyperparams.loss_alpha,
        num_eikonal_samples=training_hyperparams.num_eikonal_samples,
        eikonal_term_default_stddev=training_hyperparams.eikonal_term_stddev)

    optimize_latent = LatentOptimization(model=model,
                                         loss_function=loss_function,
                                         params=latent_optimization_params)

    grid_max_value = 1
    grid_min_value = -1
    marching_cubes = MarchingCubes(model=model,
                                   grid_size=args.grid_size,
                                   grid_max_value=grid_max_value,
                                   grid_min_value=grid_min_value)

    def initial_z_func():
        # Small random latent code used as the optimization starting point.
        return torch.normal(mean=0,
                            std=0.01,
                            size=(1, model_hyperparams.z_dim),
                            dtype=torch.float32).to(device)

    num_point_samples_list = [50, 100, 300, 1000]
    plot = Plot(minibatch_generator=minibatch_generator,
                num_point_samples_list=num_point_samples_list,
                initial_z_func=initial_z_func,
                optimize_latent_func=optimize_latent,
                marching_cubes_func=marching_cubes)

    for raw_data_pair in dataset:
        pc_data: UniformPointCloudData = raw_data_pair[0]
        mesh_data: MeshData = raw_data_pair[1]

        fig = plot(pc_data, mesh_data)

        # The shape ids are the two directories above the point-cloud file:
        # .../<category_id>/<model_id>/<file>.
        category_id, model_id = Path(pc_data.path).parts[-3:-1]

        figure_path = output_directory / f"{category_id}_{model_id}.png"
        fig.tight_layout()
        fig.subplots_adjust(top=0.92)
        fig.savefig(figure_path, dpi=300, bbox_inches="tight", pad_inches=0.05)
        plt.close(fig)
        print(figure_path, flush=True)


@client.command(name="plot_dataset")
@click.argument("--checkpoint-directory", type=str, required=True)
@click.argument("--checkpoint-epoch", type=int, default=None)
@click.argument("--output-directory", type=str, required=True)
@click.argument("--pc-dataset-directory", type=str, required=True)
@click.argument("--mesh-dataset-directory", type=str, required=True)
@click.argument("--test-split-path", type=str, required=True)
@click.argument("--grid-size", type=int, default=256)
@click.argument("--seed", type=int, default=0)
@click.hyperparameter_class(LatentOptimizationHyperparameters)
def plot_dataset(
        args, latent_optimization_params: LatentOptimizationHyperparameters):
    """Plot reconstructions for every shape listed in a test split.

    ``--test-split-path`` is a JSON object mapping each category id to a
    list of model ids. For each pair, the point cloud is read from
    ``<pc-dataset-directory>/<category>/<model>/point_cloud.npz`` and the mesh
    from ``<mesh-dataset-directory>/<category>/<model>/models/model_normalized.obj``.
    Pairs with either file missing are skipped. One PNG per shape is written
    to ``--output-directory``.

    Raises:
        FileNotFoundError: if the checkpoint files or the split file is
            missing.
    """
    device = torch.device("cuda", 0)
    # Seed every RNG so runs are reproducible; the initial latent codes are
    # drawn from torch's global generator.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    pc_dataset_directory = Path(args.pc_dataset_directory)
    mesh_dataset_directory = Path(args.mesh_dataset_directory)
    output_directory = Path(args.output_directory)
    mkdir(output_directory)

    checkpoint_directory = Path(args.checkpoint_directory)
    args_path = checkpoint_directory / "args.json"
    if args.checkpoint_epoch is None:
        model_path = checkpoint_directory / "model.pt"
    else:
        model_path = checkpoint_directory / f"model.{args.checkpoint_epoch}.pt"
    test_split_path = Path(args.test_split_path)
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    for required_path in (args_path, model_path, test_split_path):
        if not required_path.is_file():
            raise FileNotFoundError(required_path)

    model_hyperparams = ModelHyperparameters.load_json(args_path)
    decoder_hyperparams = DecoderHyperparameters.load_json(args_path)
    training_hyperparams = TrainingHyperparameters.load_json(args_path)

    model = setup_model(model_hyperparams, decoder_hyperparams)
    load_model(model_path, model)
    model.to(device)
    model.eval()

    with open(test_split_path, encoding="utf-8") as f:
        split = json.load(f)

    pc_mesh_path_list = []
    for category_id, model_id_list in split.items():
        for model_id in model_id_list:
            pc_data_path = (pc_dataset_directory / category_id / model_id /
                            "point_cloud.npz")
            mesh_data_path = (mesh_dataset_directory / category_id /
                              model_id / "models" / "model_normalized.obj")
            # Both files are read later; skip incomplete pairs up front
            # rather than crashing midway through the dataset.
            if not (pc_data_path.exists() and mesh_data_path.exists()):
                continue
            pc_mesh_path_list.append((pc_data_path, mesh_data_path))
    print(len(pc_mesh_path_list))

    dataset = CombinedDataset(pc_mesh_path_list,
                              [read_uniform_point_cloud_data, read_mesh_data])
    minibatch_generator = MinibatchGenerator(
        with_normal=training_hyperparams.with_normal, device=device)

    # Build the loss from the same stored training hyperparameters as the
    # rest of the configuration, including the eikonal sampling stddev.
    loss_function = LossFunction(
        tau=training_hyperparams.loss_tau,
        lam=training_hyperparams.loss_lambda,
        alpha=training_hyperparams.loss_alpha,
        num_eikonal_samples=training_hyperparams.num_eikonal_samples,
        eikonal_term_default_stddev=training_hyperparams.eikonal_term_stddev)

    optimize_latent = LatentOptimization(model=model,
                                         loss_function=loss_function,
                                         params=latent_optimization_params)

    grid_max_value = 1
    grid_min_value = -1
    marching_cubes = MarchingCubes(model=model,
                                   grid_size=args.grid_size,
                                   grid_max_value=grid_max_value,
                                   grid_min_value=grid_min_value)

    def initial_z_func():
        # Small random latent code used as the optimization starting point.
        return torch.normal(mean=0,
                            std=0.01,
                            size=(1, model_hyperparams.z_dim),
                            dtype=torch.float32).to(device)

    num_point_samples_list = [50, 100, 500, 1000, 30000]
    plot = Plot(minibatch_generator=minibatch_generator,
                num_point_samples_list=num_point_samples_list,
                initial_z_func=initial_z_func,
                optimize_latent_func=optimize_latent,
                marching_cubes_func=marching_cubes)

    for raw_data_pair in dataset:
        pc_data: UniformPointCloudData = raw_data_pair[0]
        mesh_data: MeshData = raw_data_pair[1]

        fig = plot(pc_data, mesh_data)

        # The shape ids are the two directories above the point-cloud file:
        # .../<category_id>/<model_id>/point_cloud.npz.
        category_id, model_id = Path(pc_data.path).parts[-3:-1]

        figure_path = output_directory / f"{category_id}_{model_id}.png"
        fig.tight_layout()
        # NOTE(review): assumes `args` exposes latent_optimization_iterations
        # (presumably added by hyperparameter_class) — confirm.
        fig.suptitle(
            f"IGR (Auto-Decoder)\nobject={category_id}_{model_id}\nlatent_optimization={args.latent_optimization_iterations}",
            fontsize=6)
        fig.subplots_adjust(top=0.92)
        fig.savefig(figure_path, dpi=300, bbox_inches="tight", pad_inches=0.05)
        plt.close(fig)
        print(figure_path, flush=True)


if __name__ == "__main__":
    # Script entry point: run the command group, which handles the
    # "plot_data" / "plot_dataset" sub-commands registered above.
    client()
