import json
import math
import shutil
from collections import defaultdict
from pathlib import Path

import colorful
import random
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import trimesh
from scipy.spatial import cKDTree as KDTree
from tabulate import tabulate

from meta_learning_sdf import click, load_model, mkdir
from meta_learning_sdf import point_cloud as pcl
from meta_learning_sdf.datasets.shapenet.uniform_sparse_sampling import (
    MeshData, MinibatchGenerator, PointCloudAndMeshPairDataset,
    UniformPointCloudData)
from meta_learning_sdf.experiments.baseline import (
    LatentOptimization, LossFunction, MarchingCubes, Model,
    ModelHyperparameters, TrainingHyperparameters, setup_model)


def _normalize(vec: np.ndarray):
    return vec / np.linalg.norm(vec)


def _look_at(eye: np.ndarray, center: np.ndarray, up: np.ndarray):
    """Derive a rotation matrix and translation vector from a camera pose.

    Three axes are built from the inputs:

    * ``z``: unit vector pointing from ``center`` towards ``eye``;
    * ``x``: ``up`` cross ``z``, normalized;
    * ``y``: ``z`` cross the (not yet normalized) ``x``, normalized.

    Returns:
        A ``(rotation_matrix, translation_vector)`` pair. The rotation matrix
        is a 3x3 ``float32`` array whose columns are ``x``, ``y`` and ``z``.
        The translation vector holds the dot product of each negated axis
        with ``eye``, in the order ``x``, ``y``, ``z``.
    """
    eye, center, up = (np.asanyarray(v) for v in (eye, center, up))

    z_axis = _normalize(eye - center)
    x_raw = np.cross(up, z_axis)
    # y is derived from the un-normalized x before either one is normalized.
    y_axis = _normalize(np.cross(z_axis, x_raw))
    x_axis = _normalize(x_raw)

    axes = (x_axis, y_axis, z_axis)
    # Row i of the matrix collects component i of every axis, so each axis
    # ends up as one column.
    rotation_matrix = np.array(
        [[axis[component] for axis in axes] for component in range(3)],
        dtype=np.float32,
    )
    translation_vector = np.array([-axis @ eye for axis in axes])

    return rotation_matrix, translation_vector


def _sample_surface_points(faces: np.ndarray, vertices: np.ndarray,
                           num_point_samples: int):
    """Sample points on the surface of the mesh given by ``vertices``/``faces``.

    A ``trimesh.Trimesh`` is built from the two arrays and passed to
    ``trimesh.sample.sample_surface`` together with ``num_point_samples``.
    Only the sampled points are returned.
    """
    surface = trimesh.Trimesh(vertices=vertices, faces=faces)
    # The second return value (named ``face_index`` in the original code) is
    # not needed by any caller in this file.
    points, _ = trimesh.sample.sample_surface(surface, num_point_samples)
    return points


@click.group()
def client():
    # Root command group. The commands defined below in this file register
    # themselves on it through ``@client.command(...)``; the group itself
    # performs no work.
    pass


@client.command(name="plot_dataset")
@click.argument("--checkpoint-directory", type=str, required=True)
@click.argument("--checkpoint-epoch", type=int, default=None)
@click.argument("--output-directory", type=str, required=True)
@click.argument("--npz-dataset-directory", type=str, required=True)
@click.argument("--obj-dataset-directory", type=str, required=True)
@click.argument("--test-split-path", type=str, required=True)
@click.argument("--grid-size", type=int, default=512)
@click.argument("--chamfer-distance-num-samples", type=int, default=30000)
def plot_dataset(args):
    """Plot decoder contours over a ground-truth mesh slice for one shape.

    Steps:

    1. Load the model from ``args.checkpoint_directory`` (``model.pt`` or
       ``model.<epoch>.pt`` when ``args.checkpoint_epoch`` is given) with the
       hyperparameters stored in ``args.json``, and move it to CUDA device 0.
    2. Read the JSON split at ``args.test_split_path`` and keep only the
       single hard-coded category/model pair below.
    3. For every retained shape whose figure does not exist yet, sample
       500000 points on the ground-truth mesh, keep those within
       ``slice_eps`` of the plane ``x = slice_x`` and draw them; then run
       latent optimisation for each configured iteration count and overlay
       the contour lines of the decoder output on that same plane.
    4. Save ``<category_id>_<model_id>.png`` in ``args.output_directory``.

    ``args.chamfer_distance_num_samples`` is accepted but not used here.

    Raises:
        FileNotFoundError: if ``args.json``, the model file or the split file
            is missing.
    """
    device = torch.device("cuda", 0)
    npz_dataset_directory = Path(args.npz_dataset_directory)
    obj_dataset_directory = Path(args.obj_dataset_directory)
    output_directory = Path(args.output_directory)
    mkdir(output_directory)

    checkpoint_directory = Path(args.checkpoint_directory)
    args_path = checkpoint_directory / "args.json"
    if args.checkpoint_epoch is None:
        model_path = checkpoint_directory / "model.pt"
    else:
        model_path = checkpoint_directory / f"model.{args.checkpoint_epoch}.pt"
    # Explicit errors instead of ``assert`` so the checks survive ``python -O``.
    for required_file in (args_path, model_path):
        if not required_file.is_file():
            raise FileNotFoundError(f"missing checkpoint file: {required_file}")

    model_hyperparams = ModelHyperparameters.load_json(args_path)
    training_hyperparams = TrainingHyperparameters.load_json(args_path)

    model = setup_model(model_hyperparams)
    load_model(model_path, model)
    model.to(device)
    model.eval()

    test_split_path = Path(args.test_split_path)
    if not test_split_path.is_file():
        raise FileNotFoundError(f"missing test split file: {test_split_path}")

    with open(test_split_path) as split_file:
        split = json.load(split_file)

    # Only this one shape is plotted; every other split entry is dropped
    # before the dataset is built.
    target_category_id = "02958343"
    target_model_id = "3469d3ab2d353da43a3afd30f2e86bd7"

    npz_obj_path_list = []
    for category_id, model_id_list in split.items():
        if category_id != target_category_id:
            continue
        for model_id in model_id_list:
            if model_id != target_model_id:
                continue
            npz_path = npz_dataset_directory / category_id / model_id / "point_cloud.npz"
            obj_path = obj_dataset_directory / category_id / model_id / "models" / "model_normalized.obj"
            npz_obj_path_list.append((npz_path, obj_path))

    random.shuffle(npz_obj_path_list)
    print(len(npz_obj_path_list))

    dataset = PointCloudAndMeshPairDataset(npz_obj_path_list)
    minibatch_generator = MinibatchGenerator(device=device)

    def kld_weight_func():
        # Use the final weight when annealing was enabled for training,
        # otherwise the initial one.
        if training_hyperparams.anneal_kld_weight:
            return training_hyperparams.loss_kld_final_weight
        return training_hyperparams.loss_kld_initial_weight

    loss_function = LossFunction(
        tau=training_hyperparams.loss_tau,
        lam=training_hyperparams.loss_lambda,
        kld_weight_func=kld_weight_func,
        num_eikonal_samples=training_hyperparams.num_eikonal_samples,
        eikonal_term_stddev=training_hyperparams.eikonal_term_stddev)

    grid_max_value = 1
    grid_min_value = -1

    latent_optimization_iterations_list = [0, 400]
    latent_optimization_lr_list = [0.0001]
    latent_optimization_num_samples_list = [100]

    # Position (first coordinate) and half-thickness of the slice that is
    # visualised.
    slice_x = -0.15
    slice_eps = 1e-3

    # The evaluation grid is identical for every subplot, so build it once.
    # Grid points are (slice_x, xx, yy): the decoder is queried on the same
    # plane the ground-truth points are taken from. The previous code
    # evaluated the decoder at x = 0 while slicing the mesh at x = -0.15, so
    # the overlaid curves belonged to different cross-sections.
    grid = np.linspace(grid_min_value, grid_max_value, args.grid_size)
    xx, yy = np.meshgrid(grid, grid)
    zz = np.full_like(xx, slice_x)
    grid_3d = np.stack((zz, xx, yy)).reshape(
        (3, -1)).transpose().astype(np.float32)
    grid_3d = torch.from_numpy(grid_3d).to(device)[None, :, :]
    num_grid_points = grid_3d.shape[1]

    for data_tuple in dataset:
        pc_data = data_tuple[0]
        mesh_data = data_tuple[1]

        # <...>/<category_id>/<model_id>/<file>: take the two directory names
        # without depending on the path separator.
        path_parts = Path(str(pc_data.path)).parts
        category_id = path_parts[-3]
        model_id = path_parts[-2]
        figure_path = output_directory / f"{category_id}_{model_id}.png"
        # Skip before the expensive surface sampling below.
        if figure_path.exists():
            continue

        # ---------------------------------------------------------------------
        # Ground truth (independent of the latent-optimisation settings)
        # ---------------------------------------------------------------------
        gt_faces = mesh_data.vertex_indices
        gt_vertices = (mesh_data.vertices + pc_data.offset) * pc_data.scale
        mesh_surface_points = _sample_surface_points(gt_faces, gt_vertices,
                                                     500000)

        plane_offset = mesh_surface_points[:, 0] - slice_x
        in_slab = (plane_offset > -slice_eps) & (plane_offset < slice_eps)
        slice_points = mesh_surface_points[in_slab]
        gt_horizontal = slice_points[:, 2]
        gt_vertical = slice_points[:, 1]
        if len(slice_points) > 0:
            print(np.min(gt_horizontal), np.max(gt_horizontal))
        else:
            # np.min / np.max raise ValueError on an empty array.
            print(f"no ground-truth samples within {slice_eps} of x={slice_x}")

        figsize_px = np.array([1000, 1500])
        dpi = 100
        figsize_inch = figsize_px / dpi
        fig = plt.figure(figsize=figsize_inch)

        for col, latent_optimization_iterations in enumerate(
                latent_optimization_iterations_list):
            for row, (latent_optimization_lr,
                      latent_optimization_num_samples) in enumerate(
                          zip(latent_optimization_lr_list,
                              latent_optimization_num_samples_list)):

                minibatch = minibatch_generator(
                    [pc_data],
                    min_num_context=latent_optimization_num_samples,
                    max_num_context=latent_optimization_num_samples,
                    min_num_target=0,
                    max_num_target=0)

                # -------------------------------------------------------------
                # Model
                # -------------------------------------------------------------
                # NOTE(review): for the first column both ``decrease_lr_every``
                # and ``max_iterations`` are 0 - confirm LatentOptimization
                # treats that as "no optimisation".
                optimize_latent = LatentOptimization(
                    model=model,
                    loss_function=loss_function,
                    lr=latent_optimization_lr,
                    decrease_lr_every=latent_optimization_iterations,
                    max_iterations=latent_optimization_iterations)
                with torch.no_grad():
                    initial_h, _ = model.encoder(
                        minibatch.context_points_list[0])
                h = optimize_latent(minibatch, initial_h)

                with torch.no_grad():
                    # One copy of the latent per grid point.
                    _h = h[:, None, :].expand(
                        (h.shape[0], num_grid_points, h.shape[1]))
                    decoded = model.decoder(X=grid_3d,
                                            h=_h).squeeze(dim=2)[0]
                    decoded = decoded.cpu().numpy()

                ff = decoded.reshape((args.grid_size, args.grid_size))

                # -------------------------------------------------------------
                # Draw: ground-truth points first, contours on top
                # -------------------------------------------------------------
                plt.subplot(3, 2, 2 * row + col + 1)
                plt.scatter(gt_horizontal,
                            gt_vertical,
                            c="k",
                            marker="o",
                            s=1)
                cs = plt.contour(yy,
                                 xx,
                                 ff,
                                 cmap="coolwarm_r",
                                 vmin=-1,
                                 vmax=1,
                                 levels=10)
                plt.clabel(cs, inline=1, fontsize=10)
                plt.xlim([-1, 1])
                plt.ylim([-1, 1])
                plt.xticks([])
                plt.yticks([])
                plt.gca().set_aspect("equal", adjustable="box")
                if col == 0:
                    plt.ylabel(latent_optimization_num_samples)
                if row == 0:
                    plt.title(f"iterations={latent_optimization_iterations}")

        plt.tight_layout()
        plt.suptitle(f"{category_id}_{model_id}")
        plt.subplots_adjust(top=0.94)
        plt.savefig(figure_path, dpi=300, bbox_inches="tight", pad_inches=0.05)
        plt.close(fig)
        print(figure_path, flush=True)


@client.command(name="compare_eikonal_term")
@click.argument("--checkpoint-epoch", type=int, default=None)
@click.argument("--output-directory", type=str, required=True)
@click.argument("--npz-dataset-directory", type=str, required=True)
@click.argument("--obj-dataset-directory", type=str, required=True)
@click.argument("--test-split-path", type=str, required=True)
@click.argument("--grid-size", type=int, default=512)
@click.argument("--latent-optimization-iterations", type=int, default=400)
@click.argument("--chamfer-distance-num-samples", type=int, default=30000)
def compare_eikonal_term(args):
    """Plot decoder contours of two checkpoints side by side for one shape.

    Two hard-coded checkpoint directories are loaded onto CUDA device 0. For
    the single hard-coded category/model pair, one figure is produced with a
    row per (learning rate, number of context points) setting and one column
    per checkpoint. Each panel shows the point-cloud vertices lying within
    ``eps`` of the plane ``x = 0`` and, on top, contour lines of the decoder
    output on that plane after latent optimisation. The figure is written to
    ``<category_id>_<model_id>.png`` in ``args.output_directory``.

    The subplot layout is 3x3 but only the first two columns are filled.
    ``args.chamfer_distance_num_samples`` is accepted but not used here.

    Raises:
        FileNotFoundError: if a checkpoint's ``args.json`` or model file, or
            the split file, is missing.
    """
    device = torch.device("cuda", 0)
    npz_dataset_directory = Path(args.npz_dataset_directory)
    obj_dataset_directory = Path(args.obj_dataset_directory)
    output_directory = Path(args.output_directory)
    mkdir(output_directory)

    grid_max_value = 1
    grid_min_value = -1

    def _load_model(checkpoint_directory: Path):
        """Load a checkpoint; return the eval-mode model and its training hyperparameters."""
        args_path = checkpoint_directory / "args.json"
        if args.checkpoint_epoch is None:
            model_path = checkpoint_directory / "model.pt"
        else:
            model_path = checkpoint_directory / f"model.{args.checkpoint_epoch}.pt"
        # Explicit errors instead of ``assert`` so the checks survive ``python -O``.
        for required_file in (args_path, model_path):
            if not required_file.is_file():
                raise FileNotFoundError(
                    f"missing checkpoint file: {required_file}")

        model_hyperparams = ModelHyperparameters.load_json(args_path)
        training_hyperparams = TrainingHyperparameters.load_json(args_path)

        model = setup_model(model_hyperparams)
        load_model(model_path, model)
        model.to(device)
        model.eval()

        return model, training_hyperparams

    def _make_loss_function(training_hyperparams) -> LossFunction:
        """Build the loss for one run entirely from that run's own hyperparameters."""

        def kld_weight_func():
            # Bound to *this* run's hyperparameters. Previously a single
            # function closed over the first run was passed to both losses.
            if training_hyperparams.anneal_kld_weight:
                return training_hyperparams.loss_kld_final_weight
            return training_hyperparams.loss_kld_initial_weight

        return LossFunction(
            tau=training_hyperparams.loss_tau,
            lam=training_hyperparams.loss_lambda,
            kld_weight_func=kld_weight_func,
            num_eikonal_samples=training_hyperparams.num_eikonal_samples,
            eikonal_term_stddev=training_hyperparams.eikonal_term_stddev)

    # Evaluation grid on the plane x = 0 (points are (0, xx, yy)); identical
    # for every panel, so it is built once.
    grid = np.linspace(grid_min_value, grid_max_value, args.grid_size)
    xx, yy = np.meshgrid(grid, grid)
    zz = np.zeros_like(xx)
    grid_3d = np.stack((zz, xx, yy)).reshape(
        (3, -1)).transpose().astype(np.float32)
    grid_3d = torch.from_numpy(grid_3d).to(device)[None, :, :]
    num_grid_points = grid_3d.shape[1]

    def _generate_sdf(model: Model, loss_function: LossFunction, minibatch,
                      lr: float):
        """Optimise the latent for ``minibatch`` and evaluate the decoder on the grid.

        Returns a ``(grid_size, grid_size)`` numpy array of decoder outputs.
        """
        optimize_latent = LatentOptimization(
            model=model,
            loss_function=loss_function,
            lr=lr,
            decrease_lr_every=args.latent_optimization_iterations,
            max_iterations=args.latent_optimization_iterations)

        with torch.no_grad():
            initial_h, _ = model.encoder(minibatch.context_points_list[0])
        h = optimize_latent(minibatch, initial_h)

        with torch.no_grad():
            # One copy of the latent per grid point.
            _h = h[:, None, :].expand(
                (h.shape[0], num_grid_points, h.shape[1]))
            decoded = model.decoder(X=grid_3d, h=_h).squeeze(dim=2)[0]
            decoded = decoded.cpu().numpy()

        return decoded.reshape((args.grid_size, args.grid_size))

    checkpoint_directories = (
        Path("checkpoints/compare_eikonal_term/result_f23fd0e76fa0"),
        Path("checkpoints/compare_eikonal_term/result_707f36b05178"),
    )
    # One (model, training hyperparameters, loss function) triple per column.
    runs = []
    for checkpoint_directory in checkpoint_directories:
        run_model, run_hyperparams = _load_model(checkpoint_directory)
        runs.append(
            (run_model, run_hyperparams, _make_loss_function(run_hyperparams)))

    test_split_path = Path(args.test_split_path)
    if not test_split_path.is_file():
        raise FileNotFoundError(f"missing test split file: {test_split_path}")

    with open(test_split_path) as split_file:
        split = json.load(split_file)

    # Only this one shape is plotted. Filtering here, rather than inside the
    # dataset loop, avoids iterating over every other entry of the split.
    target_category_id = "02958343"
    target_model_id = "ea76015145946dffc0896a3cd08800fe"

    npz_obj_path_list = []
    for category, model_id_list in split.items():
        if category != target_category_id:
            continue
        for model_id in model_id_list:
            if model_id != target_model_id:
                continue
            npz_path = npz_dataset_directory / category / model_id / "point_cloud.npz"
            obj_path = obj_dataset_directory / category / model_id / "models" / "model_normalized.obj"
            npz_obj_path_list.append((npz_path, obj_path))

    print(len(npz_obj_path_list))

    dataset = PointCloudAndMeshPairDataset(npz_obj_path_list)
    minibatch_generator = MinibatchGenerator(device=device)

    latent_optimization_lr_list = [0.0001, 0.0001, 0.005]
    latent_optimization_num_samples_list = [100, 1000, 30000]

    for data_tuple in dataset:
        pc_data = data_tuple[0]

        # <...>/<category_id>/<model_id>/<file>: take the two directory names
        # without depending on the path separator.
        path_parts = Path(str(pc_data.path)).parts
        category_id = path_parts[-3]
        model_id = path_parts[-2]
        figure_path = output_directory / f"{category_id}_{model_id}.png"

        figsize_px = np.array([1500, 1500])
        dpi = 100
        figsize_inch = figsize_px / dpi
        fig = plt.figure(figsize=figsize_inch)

        levels = np.linspace(-1, 1, num=21)

        # ---------------------------------------------------------------------
        # Ground truth (the same for every row and column)
        # ---------------------------------------------------------------------
        surface_points = pc_data.vertices
        first_coordinate = surface_points[:, 0]
        eps = 1e-3
        in_slab = (first_coordinate > -eps) & (first_coordinate < eps)
        slice_points = surface_points[in_slab]
        gt_horizontal = slice_points[:, 2]
        gt_vertical = slice_points[:, 1]

        for row, (latent_optimization_lr,
                  latent_optimization_num_samples) in enumerate(
                      zip(latent_optimization_lr_list,
                          latent_optimization_num_samples_list)):
            # Both checkpoints are evaluated on the same sampled context.
            minibatch = minibatch_generator(
                [pc_data],
                min_num_context=latent_optimization_num_samples,
                max_num_context=latent_optimization_num_samples,
                min_num_target=0,
                max_num_target=0)

            for col, (run_model, run_hyperparams,
                      run_loss_function) in enumerate(runs):
                ff = _generate_sdf(run_model, run_loss_function, minibatch,
                                   latent_optimization_lr)

                # Ground-truth points first, contours on top.
                plt.subplot(3, 3, 3 * row + col + 1)
                plt.scatter(gt_horizontal,
                            gt_vertical,
                            c="k",
                            marker="o",
                            s=2)
                cs = plt.contour(yy,
                                 xx,
                                 ff,
                                 cmap="coolwarm_r",
                                 vmin=-1,
                                 vmax=1,
                                 levels=levels)
                plt.clabel(cs, inline=1, fontsize=10)
                plt.xlim([-1, 1])
                plt.ylim([-1, 1])
                plt.xticks([])
                plt.yticks([])
                plt.gca().set_aspect("equal", adjustable="box")
                if row == 0:
                    plt.title(
                        f"stddev={run_hyperparams.eikonal_term_stddev},lambda={run_hyperparams.loss_lambda}"
                    )
                plt.ylabel(latent_optimization_num_samples)

        plt.tight_layout()
        plt.savefig(figure_path, dpi=300, bbox_inches="tight", pad_inches=0.05)
        plt.close(fig)
        print(figure_path, flush=True)


# Invoke the ``client`` command group only when this file is run as a script.
if __name__ == "__main__":
    client()
