# https://github.com/pytorch/pytorch/issues/19739
import json
import math
import random
from pathlib import Path
from typing import Callable, List

import matplotlib.pyplot as plt
import numpy as np
import open3d
import torch
import torch.optim as optim
import trimesh
from skimage import measure

from deep_sdf import click, load_model, mkdir
from deep_sdf import point_cloud as pcl
from deep_sdf.datasets.shapenet import (CombinedDataset, MeshData, Minibatch,
                                        MinibatchGenerator, SdfData)
from deep_sdf.experiments.learning_shape_space import (
    Decoder, DecoderHyperparameters, LatentOptimization, LossFunction,
    MarchingCubes, Model, ModelHyperparameters, TrainingHyperparameters,
    look_at, sample_surface_points, setup_model)


class Plot:
    """Build a figure comparing latent-optimization reconstructions.

    Each entry of ``num_sdf_samples_list`` produces one figure row with three
    panels: the sampled input points, the mesh extracted from the optimized
    latent code, and the ground-truth mesh. Calling the instance returns the
    matplotlib figure; the caller is responsible for saving and closing it.
    """

    def __init__(self, minibatch_generator: MinibatchGenerator,
                 num_sdf_samples_list: List[int],
                 initial_z_func: Callable[[], torch.Tensor],
                 optimize_latent_func: Callable[[Minibatch, torch.Tensor],
                                                torch.Tensor],
                 marching_cubes_func: Callable[[torch.Tensor], tuple]):
        """Store the collaborators used to build each figure row.

        Args:
            minibatch_generator: called as ``generator([sdf_data],
                num_sdf_samples=..., random_state=...)`` to sample inputs.
            num_sdf_samples_list: number of SDF samples used for each row.
            initial_z_func: returns a fresh initial latent code.
            optimize_latent_func: called as ``f(minibatch, initial_z)``;
                returns the optimized latent code.
            marching_cubes_func: called as ``f(z)``; returns
                ``(vertices, faces)``. A ``ValueError`` leaves that row's
                reconstruction panel empty.
        """
        self.minibatch_generator = minibatch_generator
        self.num_sdf_samples_list = num_sdf_samples_list
        self.get_initial_z = initial_z_func
        self.optimize_latent = optimize_latent_func
        self.marching_cubes = marching_cubes_func

    @staticmethod
    def _camera_transform():
        """Return ``(rotation_matrix, translation_vector)`` for the fixed view.

        The camera sits on a unit sphere looking at the origin with +y up.
        ``translation_vector`` is returned as a (1, 3) row so it broadcasts
        over (N, 3) point arrays.
        """
        camera_theta = math.pi / 3
        camera_phi = -math.pi / 4
        camera_r = 1
        eye = [
            camera_r * math.sin(camera_theta) * math.cos(camera_phi),
            camera_r * math.cos(camera_theta),
            camera_r * math.sin(camera_theta) * math.sin(camera_phi),
        ]
        rotation_matrix, translation_vector = look_at(eye=eye,
                                                      center=[0, 0, 0],
                                                      up=[0, 1, 0])
        return np.linalg.inv(rotation_matrix), translation_vector[None, :]

    @staticmethod
    def _transform(vertices: np.ndarray, rotation_matrix: np.ndarray,
                   translation_vector: np.ndarray) -> np.ndarray:
        """Apply the camera transform to an (N, 3) array of points."""
        return (rotation_matrix @ vertices.T).T + translation_vector

    @staticmethod
    def _cmap_binary(points: np.ndarray) -> np.ndarray:
        """Return dark-gray RGB colors that shade ``points`` by x coordinate.

        The intensity falls linearly from 0.3 at the most negative x to 0 at
        the most positive x. ``points`` is not modified.
        """
        x = points[:, 0]
        max_abs = np.max(np.abs(x))
        if max_abs > 0:
            scale = 1 / max_abs
            # Out-of-place: the original `x *= ...` wrote through the view
            # into the caller's array.
            normalized = x * -scale
        else:
            # Every point has x == 0; use the mid intensity instead of NaN.
            normalized = np.zeros_like(x)
        intensity = 0.3 * (normalized + 1) / 2
        return np.repeat(intensity[:, None], 3, axis=1)

    @staticmethod
    def _hide_ticks(ax):
        """Remove both axis tick sets from an image panel."""
        ax.set_xticks([])
        ax.set_yticks([])

    def __call__(self, sdf_data: SdfData, mesh_data: MeshData):
        """Optimize a latent code per sample count and plot the results.

        Args:
            sdf_data: SDF samples of one shape; its ``offset`` and ``scale``
                are also applied to the ground-truth mesh vertices.
            mesh_data: ground-truth mesh (``vertices`` / ``vertex_indices``).

        Returns:
            The matplotlib figure with ``len(num_sdf_samples_list)`` rows.
        """
        rotation_matrix, translation_vector = self._camera_transform()

        num_rows = len(self.num_sdf_samples_list)
        dpi = 100
        figsize_inch = np.array([600, 200 * num_rows]) / dpi

        # The ground truth does not depend on the row, so render it once.
        # NOTE(review): (vertices + offset) * scale presumably maps the mesh
        # into the frame the SDF samples live in; confirm against SdfData.
        gt_vertices = (mesh_data.vertices + sdf_data.offset) * sdf_data.scale
        gt_mesh = trimesh.Trimesh(
            vertices=self._transform(gt_vertices, rotation_matrix,
                                     translation_vector),
            faces=mesh_data.vertex_indices)
        gt_image = pcl.render_mesh(gt_mesh, camera_mag=1)

        # squeeze=False keeps `axes` 2-D even when there is a single row.
        fig, axes = plt.subplots(num_rows,
                                 3,
                                 figsize=figsize_inch,
                                 squeeze=False)

        for row, num_sdf_samples in enumerate(self.num_sdf_samples_list):
            input_ax, recon_ax, gt_ax = axes[row]

            # A fresh fixed-seed RandomState makes each row's sampling
            # reproducible and independent of the other rows.
            random_state = np.random.RandomState(0)
            data = self.minibatch_generator([sdf_data],
                                            num_sdf_samples=num_sdf_samples,
                                            random_state=random_state)

            z = self.optimize_latent(data, self.get_initial_z())
            print("optimize_latent done")

            # Panel 1: input point cloud.
            input_points = data.points[0].detach().cpu().numpy()
            input_points = input_points.reshape((-1, 3))
            image = pcl.render_point_cloud(
                self._transform(input_points, rotation_matrix,
                                translation_vector),
                self._cmap_binary(input_points),
                camera_mag=1,
                point_size=6)
            input_ax.imshow(image)
            input_ax.set_ylabel(num_sdf_samples)

            # Panel 2: reconstruction. Only the extraction step is guarded;
            # a failure is reported and the panel is left empty.
            try:
                mc_vertices, mc_faces = self.marching_cubes(z)
            except ValueError as error:
                print(f"marching_cubes failed "
                      f"(num_sdf_samples={num_sdf_samples}): {error}")
            else:
                print("marching_cubes done")
                mesh = trimesh.Trimesh(
                    vertices=self._transform(mc_vertices, rotation_matrix,
                                             translation_vector),
                    faces=mc_faces)
                recon_ax.imshow(pcl.render_mesh(mesh, camera_mag=1))

            # Panel 3: ground truth.
            gt_ax.imshow(gt_image)

            for ax in (input_ax, recon_ax, gt_ax):
                self._hide_ticks(ax)

            if row == 0:
                input_ax.set_title("Input points", fontsize=10)
                recon_ax.set_title("Reconstruction", fontsize=10)
                gt_ax.set_title("Ground truth", fontsize=10)
        return fig


@click.group()
def client():
    """Top-level command group for the plotting sub-commands below."""


@client.command(name="plot_data")
@click.argument("--checkpoint-directory", type=str, required=True)
@click.argument("--checkpoint-epoch", type=int, default=None)
@click.argument("--output-directory", type=str, required=True)
@click.argument("--npz-data", type=str, required=True)
@click.argument("--obj-data", type=str, required=True)
@click.argument("--grid-size", type=int, default=256)
@click.argument("--latent-optimization-iterations", type=int, default=800)
@click.argument("--latent-optimization-initial-lr", type=float, default=0.005)
@click.argument("--seed", type=int, default=0)
def plot_data(args):
    """Plot the reconstruction of a single (npz, obj) shape pair.

    Loads the checkpoint in ``--checkpoint-directory`` (``model.pt`` or
    ``model.<epoch>.pt`` plus ``args.json``), optimizes a latent code from
    100 SDF samples, and writes ``<category>_<model>.png`` into
    ``--output-directory``.

    Raises:
        FileNotFoundError: if ``args.json`` or the model file is missing.
    """
    # --seed used to be accepted but ignored; seed every RNG the pipeline
    # may draw from (e.g. the initial latent code) for reproducible runs.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    device = torch.device("cuda", 0)
    output_directory = Path(args.output_directory)
    mkdir(output_directory)

    checkpoint_directory = Path(args.checkpoint_directory)
    args_path = checkpoint_directory / "args.json"
    if args.checkpoint_epoch is None:
        model_path = checkpoint_directory / "model.pt"
    else:
        model_path = checkpoint_directory / f"model.{args.checkpoint_epoch}.pt"
    # Explicit checks instead of `assert`, which is stripped under -O.
    for required_path in (args_path, model_path):
        if not required_path.is_file():
            raise FileNotFoundError(required_path)

    model_hyperparams = ModelHyperparameters.load_json(args_path)
    decoder_hyperparams = DecoderHyperparameters.load_json(args_path)
    training_hyperparams = TrainingHyperparameters.load_json(args_path)

    model = setup_model(model_hyperparams, decoder_hyperparams)
    load_model(model_path, model)
    model.to(device)
    model.eval()

    # NOTE(review): SdfAndMeshPairDataset is not imported in this module
    # (CombinedDataset is imported but unused); confirm where it lives.
    dataset = SdfAndMeshPairDataset([(args.npz_data, args.obj_data)])
    minibatch_generator = MinibatchGenerator(device=device)

    # Use the checkpoint's own loss settings, as plot_dataset does, rather
    # than hard-coded values that may disagree with training.
    loss_function = LossFunction(
        lam=training_hyperparams.loss_lam,
        clamping_distance=training_hyperparams.clamping_distance)

    optimize_latent = LatentOptimization(
        model=model,
        loss_function=loss_function,
        lr=args.latent_optimization_initial_lr,
        decrease_lr_every=args.latent_optimization_iterations // 2,
        max_iterations=args.latent_optimization_iterations)

    grid_max_value = 1
    grid_min_value = -1
    marching_cubes = MarchingCubes(model=model,
                                   grid_size=args.grid_size,
                                   grid_max_value=grid_max_value,
                                   grid_min_value=grid_min_value)

    def initial_z_func():
        """Draw a small random initial latent code on `device`."""
        return torch.normal(mean=0,
                            std=0.01,
                            size=(1, model_hyperparams.z_dim),
                            dtype=torch.float32).to(device)

    plot = Plot(minibatch_generator=minibatch_generator,
                num_sdf_samples_list=[100],
                initial_z_func=initial_z_func,
                optimize_latent_func=optimize_latent,
                marching_cubes_func=marching_cubes)

    for raw_data_pair in dataset:
        sdf_data: SdfData = raw_data_pair[0]
        mesh_data: MeshData = raw_data_pair[1]

        fig = plot(sdf_data, mesh_data)
        try:
            # The file name comes from the .../<category>/<model>/<file> layout.
            parts = Path(sdf_data.path).parts
            category_id = parts[-3]
            model_id = parts[-2]

            figure_path = output_directory / f"{category_id}_{model_id}.png"
            fig.tight_layout()
            fig.savefig(figure_path,
                        dpi=300,
                        bbox_inches="tight",
                        pad_inches=0.05)
        finally:
            # Always release the figure, even if saving fails.
            plt.close(fig)
        print(figure_path, flush=True)


@client.command(name="plot_dataset")
@click.argument("--checkpoint-directory", type=str, required=True)
@click.argument("--checkpoint-epoch", type=int, default=None)
@click.argument("--output-directory", type=str, required=True)
@click.argument("--npz-dataset-directory", type=str, required=True)
@click.argument("--obj-dataset-directory", type=str, required=True)
@click.argument("--test-split-path", type=str, required=True)
@click.argument("--grid-size", type=int, default=256)
@click.argument("--latent-optimization-iterations", type=int, default=800)
@click.argument("--latent-optimization-initial-lr", type=float, default=0.005)
@click.argument("--seed", type=int, default=0)
def plot_dataset(args):
    """Plot reconstructions for every model listed in a test split.

    The split JSON maps ``category_id`` to a list of ``model_id`` values.
    For each model, ``<npz_dir>/<cat>/<model>/sdf.npz`` and
    ``<obj_dir>/<cat>/<model>/models/model_normalized.obj`` are paired. A
    model is skipped when either file is missing. One figure with rows for
    50, 100, 500, 1000, and 30000 SDF samples is written per model as
    ``<category>_<model>.png``.

    Raises:
        FileNotFoundError: if ``args.json``, the model file, or the split
            file is missing.
    """
    # --seed used to be accepted but ignored; seed every RNG the pipeline
    # may draw from (e.g. the initial latent code) for reproducible runs.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    device = torch.device("cuda", 0)
    npz_dataset_directory = Path(args.npz_dataset_directory)
    obj_dataset_directory = Path(args.obj_dataset_directory)
    output_directory = Path(args.output_directory)
    mkdir(output_directory)

    checkpoint_directory = Path(args.checkpoint_directory)
    args_path = checkpoint_directory / "args.json"
    if args.checkpoint_epoch is None:
        model_path = checkpoint_directory / "model.pt"
    else:
        model_path = checkpoint_directory / f"model.{args.checkpoint_epoch}.pt"
    # Explicit checks instead of `assert`, which is stripped under -O.
    for required_path in (args_path, model_path):
        if not required_path.is_file():
            raise FileNotFoundError(required_path)

    model_hyperparams = ModelHyperparameters.load_json(args_path)
    decoder_hyperparams = DecoderHyperparameters.load_json(args_path)
    training_hyperparams = TrainingHyperparameters.load_json(args_path)

    model = setup_model(model_hyperparams, decoder_hyperparams)
    load_model(model_path, model)
    model.to(device)
    model.eval()

    test_split_path = Path(args.test_split_path)
    if not test_split_path.is_file():
        raise FileNotFoundError(test_split_path)

    with open(test_split_path, encoding="utf-8") as f:
        split = json.load(f)

    npz_obj_path_list = []
    for category_id, model_id_list in split.items():
        for model_id in model_id_list:
            npz_path = npz_dataset_directory / category_id / model_id / "sdf.npz"
            obj_path = (obj_dataset_directory / category_id / model_id /
                        "models" / "model_normalized.obj")
            # Both halves of the pair are needed; skip incomplete models
            # instead of failing later while loading the dataset.
            if not (npz_path.exists() and obj_path.exists()):
                continue
            npz_obj_path_list.append((npz_path, obj_path))

    print(len(npz_obj_path_list))

    # NOTE(review): SdfAndMeshPairDataset is not imported in this module
    # (CombinedDataset is imported but unused); confirm where it lives.
    dataset = SdfAndMeshPairDataset(npz_obj_path_list)
    minibatch_generator = MinibatchGenerator(device=device)

    loss_function = LossFunction(
        lam=training_hyperparams.loss_lam,
        clamping_distance=training_hyperparams.clamping_distance)

    optimize_latent = LatentOptimization(
        model=model,
        loss_function=loss_function,
        lr=args.latent_optimization_initial_lr,
        decrease_lr_every=args.latent_optimization_iterations // 2,
        max_iterations=args.latent_optimization_iterations)

    grid_max_value = 1
    grid_min_value = -1
    marching_cubes = MarchingCubes(model=model,
                                   grid_size=args.grid_size,
                                   grid_max_value=grid_max_value,
                                   grid_min_value=grid_min_value)

    def initial_z_func():
        """Draw a small random initial latent code on `device`."""
        return torch.normal(mean=0,
                            std=0.01,
                            size=(1, model_hyperparams.z_dim),
                            dtype=torch.float32).to(device)

    num_sdf_samples_list = [50, 100, 500, 1000, 30000]
    plot = Plot(minibatch_generator=minibatch_generator,
                num_sdf_samples_list=num_sdf_samples_list,
                initial_z_func=initial_z_func,
                optimize_latent_func=optimize_latent,
                marching_cubes_func=marching_cubes)

    for raw_data_pair in dataset:
        sdf_data: SdfData = raw_data_pair[0]
        mesh_data: MeshData = raw_data_pair[1]

        fig = plot(sdf_data, mesh_data)
        try:
            # The file name comes from the .../<category>/<model>/<file> layout.
            parts = Path(sdf_data.path).parts
            category_id = parts[-3]
            model_id = parts[-2]

            figure_path = output_directory / f"{category_id}_{model_id}.png"
            fig.tight_layout()
            fig.suptitle(
                f"DeepSDF (Auto-Decoder)\nobject={category_id}_{model_id}\nlatent_optimization={args.latent_optimization_iterations}",
                fontsize=6)
            # Leave room at the top for the suptitle.
            fig.subplots_adjust(top=0.92)
            fig.savefig(figure_path,
                        dpi=300,
                        bbox_inches="tight",
                        pad_inches=0.05)
        finally:
            # Always release the figure, even if saving fails.
            plt.close(fig)
        print(figure_path, flush=True)


if __name__ == "__main__":
    # Script entry point: hand control to the `client` command group.
    client()
