import json
import math
from pathlib import Path
from typing import Callable, List

import colorful
import matplotlib.pyplot as plt
import numpy as np
import torch
import trimesh
from scipy.spatial import cKDTree as KDTree

from meta_learning_sdf import click, load_model, mkdir
from meta_learning_sdf import point_cloud as pcl
from meta_learning_sdf.datasets.shapenet.uniform_sparse_sampling import (
    MeshData, Minibatch, MinibatchGenerator, UniformPointCloudDataset,
    UniformPointCloudData)
from meta_learning_sdf.experiments.baseline import (
    LatentOptimization, LossFunction, MarchingCubes, Model,
    ModelHyperparameters, TrainingHyperparameters, setup_model,
    LatentOptimizationHyperparameters)


def _normalize(vec: np.ndarray):
    return vec / np.linalg.norm(vec)


def _look_at(eye: np.ndarray, center: np.ndarray, up: np.ndarray):
    """Build a look-at camera frame for a camera at ``eye`` aimed at ``center``.

    Args:
        eye: Camera position (array-like with three components).
        center: Point the camera looks at.
        up: Approximate up direction used to derive the camera x axis.

    Returns:
        A ``(rotation_matrix, translation_vector)`` tuple. The rotation matrix
        is a 3x3 float32 array whose columns are the camera x, y and z axes;
        the translation vector holds, per axis, the dot product of the
        negated axis with ``eye``.
    """
    eye, center, up = (np.asanyarray(v) for v in (eye, center, up))

    # The z axis points from the target back towards the camera.
    backward = _normalize(eye - center)
    right = np.cross(up, backward)
    true_up = np.cross(backward, right)

    right = _normalize(right)
    true_up = _normalize(true_up)

    camera_axes = (right, true_up, backward)
    # Stacking along axis 1 places each axis vector in its own column.
    rotation_matrix = np.stack(camera_axes, axis=1).astype(np.float32)
    translation_vector = np.array([-axis @ eye for axis in camera_axes])

    return rotation_matrix, translation_vector


def _sample_surface_points(faces: np.ndarray, vertices: np.ndarray,
                           num_point_samples: int):
    """Draw ``num_point_samples`` points from the surface of a mesh.

    The mesh is built from ``vertices`` and ``faces`` with ``trimesh``. Only
    the sampled positions are returned; the per-sample face indices reported
    by ``trimesh.sample.sample_surface`` are discarded.
    """
    surface = trimesh.Trimesh(vertices=vertices, faces=faces)
    points, _face_indices = trimesh.sample.sample_surface(
        surface, num_point_samples)
    return points


def _compute_chamfer_distance(gt_surface_points: np.ndarray,
                              pred_surface_points: np.ndarray):
    gt_points_kd_tree = KDTree(gt_surface_points)
    distances, locations = gt_points_kd_tree.query(pred_surface_points)
    cd_term1 = np.mean(np.square(distances))

    pred_points_kd_tree = KDTree(pred_surface_points)
    distances, locations = pred_points_kd_tree.query(gt_surface_points)
    cd_term2 = np.mean(np.square(distances))

    return cd_term1 + cd_term2


def _compute_non_squared_chamfer_distance(gt_surface_points: np.ndarray,
                                          pred_surface_points: np.ndarray):
    gt_points_kd_tree = KDTree(gt_surface_points)
    distances, locations = gt_points_kd_tree.query(pred_surface_points)
    cd_term1 = distances

    pred_points_kd_tree = KDTree(pred_surface_points)
    distances, locations = pred_points_kd_tree.query(gt_surface_points)
    cd_term2 = distances

    return cd_term1, cd_term2


class NonSquaredChamferDistance:
    """Symmetric, non-squared Chamfer distance of a reconstructed mesh.

    The latent code is obtained from the encoder and refined with
    ``optimize_latent_func``; the surface is then extracted with
    ``MarchingCubes`` on a ``[-1, 1]`` grid and compared against points
    sampled from the ground-truth mesh.
    """

    def __init__(self, minibatch_generator: MinibatchGenerator,
                 optimize_latent_func: LatentOptimization, model: Model,
                 grid_size: int, latent_optimization_num_samples: int,
                 chamfer_distance_num_samples: int):
        self.minibatch_generator = minibatch_generator
        self.optimize_latent = optimize_latent_func
        self.model = model
        self.grid_size = grid_size
        self.latent_optimization_num_samples = latent_optimization_num_samples
        self.chamfer_distance_num_samples = chamfer_distance_num_samples

    def __call__(self, pc_data: UniformPointCloudData, mesh_data: MeshData):
        """Evaluate one shape.

        Args:
            pc_data: Point cloud used as context for latent optimization; its
                ``scale`` and ``offset`` are applied to the extracted mesh.
            mesh_data: Ground-truth mesh (``vertices`` / ``vertex_indices``).

        Returns:
            The sum of the two mean nearest-neighbour distances
            (prediction -> ground truth and ground truth -> prediction), or
            ``None`` when surface extraction raises ``ValueError``.
        """
        gt_faces = mesh_data.vertex_indices
        gt_vertices = mesh_data.vertices

        # Context-only minibatch: no target points are requested.
        data = self.minibatch_generator(
            [pc_data],
            min_num_context=self.latent_optimization_num_samples,
            max_num_context=self.latent_optimization_num_samples,
            min_num_target=0,
            max_num_target=0)
        with torch.no_grad():
            initial_h, _ = self.model.encoder(data.context_points_list[0])
        h = self.optimize_latent(data, initial_h)

        grid_max_value = 1
        grid_min_value = -1
        marching_cubes = MarchingCubes(model=self.model,
                                       grid_size=self.grid_size,
                                       grid_max_value=grid_max_value,
                                       grid_min_value=grid_min_value)
        try:
            mc_vertices, mc_faces = marching_cubes(h)
            # Presumably maps the mesh back to the ground-truth mesh frame
            # (inverse of ``(v + offset) * scale``) -- TODO confirm.
            mc_vertices = mc_vertices / pc_data.scale - pc_data.offset
        except ValueError:
            return None

        gt_surface_points = _sample_surface_points(
            gt_faces, gt_vertices, self.chamfer_distance_num_samples)
        pred_surface_points = _sample_surface_points(
            mc_faces, mc_vertices, self.chamfer_distance_num_samples)

        (chamfer_distance_term_1,
         chamfer_distance_term_2) = _compute_non_squared_chamfer_distance(
             gt_surface_points, pred_surface_points)

        # Diagnostic output: distribution of the per-point distances.
        hist, bins = np.histogram(chamfer_distance_term_1)
        print(hist)
        print(bins)
        hist, bins = np.histogram(chamfer_distance_term_2)
        print(hist)
        print(bins)

        symmetric_chamfer_distance = np.mean(
            chamfer_distance_term_1) + np.mean(chamfer_distance_term_2)
        # The symmetric distance is a scalar, so it is printed directly
        # rather than being passed through a histogram.
        print(symmetric_chamfer_distance)

        # Previously an undefined name was returned here, which raised
        # NameError on every successful evaluation.
        return symmetric_chamfer_distance


class ChamferDistanceMeanSdf:
    """Mean absolute decoder output at sampled target points.

    The latent code is taken from the encoder, refined with
    ``optimize_latent_func`` and then used to decode the target points drawn
    by the minibatch generator; the mean of the absolute decoder outputs is
    reported.
    """

    def __init__(self, minibatch_generator: MinibatchGenerator,
                 optimize_latent_func: LatentOptimization, model: Model,
                 grid_size: int, latent_optimization_num_samples: int,
                 chamfer_distance_num_samples: int):
        self.model = model
        self.minibatch_generator = minibatch_generator
        self.optimize_latent = optimize_latent_func
        self.grid_size = grid_size
        self.latent_optimization_num_samples = latent_optimization_num_samples
        self.chamfer_distance_num_samples = chamfer_distance_num_samples

    def __call__(self, pc_data: UniformPointCloudData, mesh_data: MeshData):
        """Return the mean absolute decoder value for ``pc_data``.

        ``mesh_data`` is accepted for signature parity with the other metric
        classes but is not read.
        """
        num_context = self.latent_optimization_num_samples
        num_target = self.chamfer_distance_num_samples

        batch = self.minibatch_generator([pc_data],
                                         min_num_context=num_context,
                                         max_num_context=num_context,
                                         min_num_target=num_target,
                                         max_num_target=num_target,
                                         tuple_target=False)

        # Keep the target points for the final evaluation, then remove the
        # targets from the minibatch before it is handed to the optimizer.
        query_points = batch.target_points
        query_points.requires_grad_(True)
        batch.target_points = None
        batch.target_normals = None

        with torch.no_grad():
            initial_h, _ = self.model.encoder(batch.context_points_list[0])
        latent = self.optimize_latent(batch, initial_h)

        # Repeat the latent code once per query point.
        tiled_latent = latent[:, None, :].expand(
            (latent.shape[0], num_target, latent.shape[1]))

        decoded = self.model.decoder(X=query_points, h=tiled_latent)
        absolute = abs(decoded).detach().cpu()
        return absolute.mean().item()


class _Plot:
    """Plot input points, reconstruction and ground-truth mesh per context size.

    Each row of the returned figure corresponds to one entry of
    ``num_point_samples_list``. Besides rendering, every row prints distance
    histograms and labels the reconstruction with a symmetric (non-squared)
    Chamfer distance and the mean absolute decoder output at target points.
    """

    def __init__(self, model: Model, minibatch_generator: MinibatchGenerator,
                 num_point_samples_list: List[int],
                 optimize_latent_func: Callable[[torch.Tensor], torch.Tensor],
                 marching_cubes_func: Callable[[torch.Tensor], None]):
        self.model = model
        self.minibatch_generator = minibatch_generator
        self.num_point_samples_list = num_point_samples_list
        self.optimize_latent = optimize_latent_func
        self.marching_cubes = marching_cubes_func

    @staticmethod
    def _print_histogram(values):
        """Print ``np.histogram`` counts and edges of ``values`` plus a table."""
        hist, bins = np.histogram(values)
        print(hist)
        print(bins)
        print("| ----------------------- | ---- |")
        for num, s, e in zip(hist, bins[:-1], bins[1:]):
            print(f"| {s:.6f} ~ {e:.6f} | {num}    |")

    def __call__(self, pc_data: UniformPointCloudData, mesh_data: MeshData):
        """Create the comparison figure for one shape.

        Args:
            pc_data: Point cloud providing context/target samples as well as
                the ``offset`` and ``scale`` applied to the mesh vertices.
            mesh_data: Ground-truth mesh (``vertices`` / ``vertex_indices``).

        Returns:
            The matplotlib figure with one row per context size and the
            columns "Input points", "Reconstruction" and "Ground truth".
        """
        # setup
        camera_theta = math.pi / 3
        camera_phi = -math.pi / 4
        camera_r = 1
        eye = [
            camera_r * math.sin(camera_theta) * math.cos(camera_phi),
            camera_r * math.cos(camera_theta),
            camera_r * math.sin(camera_theta) * math.sin(camera_phi),
        ]
        rotation_matrix, translation_vector = _look_at(eye=eye,
                                                       center=[0, 0, 0],
                                                       up=[0, 1, 0])
        translation_vector = translation_vector[None, :]
        rotation_matrix = np.linalg.inv(rotation_matrix)

        figsize_px = np.array([600, 200 * len(self.num_point_samples_list)])
        dpi = 100
        figsize_inch = figsize_px / dpi

        # Number of target / surface samples used for the printed metrics.
        num_metric_samples = 30000

        # plot gt
        gt_faces = mesh_data.vertex_indices
        gt_vertices = (mesh_data.vertices + pc_data.offset) * pc_data.scale
        gt_vertices = (rotation_matrix @ gt_vertices.T).T + translation_vector
        gt_mesh = trimesh.Trimesh(vertices=gt_vertices, faces=gt_faces)
        gt_image = pcl.render_mesh(gt_mesh, camera_mag=1)

        fig, axes = plt.subplots(len(self.num_point_samples_list),
                                 3,
                                 figsize=figsize_inch)
        if len(self.num_point_samples_list) == 1:
            # A single row comes back as a 1-D array; wrap it so that
            # ``axes[row][col]`` indexing works uniformly.
            axes = [axes]

        def cmap_binary(points: np.ndarray):
            # Work on a copy: the in-place scaling below must not modify the
            # caller's array (which may share memory with a tensor).
            points = points.copy()
            x = points[:, 0]
            scale = 1 / np.max(np.abs(x))
            x *= -scale
            intensity = 0.3 * (x + 1) / 2
            rgb = np.repeat(intensity[:, None], 3, axis=1)
            return rgb

        for row, num_point_samples in enumerate(self.num_point_samples_list):
            print(f"row {row+1} of {len(self.num_point_samples_list)}",
                  flush=True)
            random_state = np.random.RandomState(0)
            minibatch = self.minibatch_generator(
                [pc_data],
                min_num_context=num_point_samples,
                max_num_context=num_point_samples,
                min_num_target=num_metric_samples,
                max_num_target=num_metric_samples,
                tuple_target=False,
                random_state=random_state)

            # Keep the target points for the decoder-based metric, then
            # remove the targets from the minibatch given to the optimizer.
            gt_cd_points = minibatch.target_points
            gt_cd_points.requires_grad_(True)

            minibatch.target_points = None
            minibatch.target_normals = None

            # input point cloud
            input_points = minibatch.context_points_list[0][0].detach().cpu(
            ).numpy()
            input_points = input_points.reshape((-1, 3))
            points = (rotation_matrix @ input_points.T).T + translation_vector
            colors = cmap_binary(input_points)
            image = pcl.render_point_cloud(points,
                                           colors,
                                           camera_mag=1,
                                           point_size=6)
            axes[row][0].imshow(image)
            axes[row][0].set_xticks([])
            axes[row][0].set_yticks([])
            axes[row][0].set_ylabel(num_point_samples)
            if row == 0:
                axes[row][0].set_title("Input points", fontsize=10)

            # make prediction
            with torch.no_grad():
                h, h_dist = self.model.encoder(
                    minibatch.context_points_list[0])
            h = self.optimize_latent(minibatch, h)
            print("optimize_latent done.", flush=True)

            reconstruction_ok = False
            try:
                mc_vertices, mc_faces = self.marching_cubes(h)
                print("marching_cubes done.", flush=True)
                # plot prediction
                mc_vertices = (
                    rotation_matrix @ mc_vertices.T).T + translation_vector
                mesh = trimesh.Trimesh(vertices=mc_vertices, faces=mc_faces)
                image = pcl.render_mesh(mesh, camera_mag=1)
                axes[row][1].imshow(image)
                reconstruction_ok = True
            except ValueError:
                # Best effort: leave the panel empty and skip the
                # surface-based metric instead of using an unbound or stale
                # mesh from a previous row.
                print("reconstruction failed; skipping surface metrics.",
                      flush=True)
            axes[row][1].set_xticks([])
            axes[row][1].set_yticks([])
            if row == 0:
                axes[row][1].set_title("Reconstruction", fontsize=10)

            symmetric_chamfer_distance = float("nan")
            if reconstruction_ok:
                # Both meshes are in the camera frame here; they went through
                # the same rotation and translation above.
                gt_surface_points = _sample_surface_points(
                    gt_faces, gt_vertices, num_metric_samples)
                pred_surface_points = _sample_surface_points(
                    mc_faces, mc_vertices, num_metric_samples)

                (chamfer_distance_term_1, chamfer_distance_term_2
                 ) = _compute_non_squared_chamfer_distance(
                     gt_surface_points, pred_surface_points)

                self._print_histogram(chamfer_distance_term_1)
                self._print_histogram(chamfer_distance_term_2)

                symmetric_chamfer_distance = np.mean(
                    chamfer_distance_term_1) + np.mean(
                        chamfer_distance_term_2)

            # Decoder-based metric: mean absolute decoder output at the
            # target points; it does not need the extracted mesh.
            _h = h[:, None, :].expand(
                (h.shape[0], num_metric_samples, h.shape[1]))
            distance = self.model.decoder(X=gt_cd_points, h=_h)
            distance = abs(distance).detach().cpu()

            self._print_histogram(distance)

            chamfer_distance = distance.mean().item()
            print(chamfer_distance)
            axes[row][1].set_xlabel(
                f"symmetric={symmetric_chamfer_distance}\nmean_sdf={chamfer_distance}",
                fontsize=6)

            axes[row][2].imshow(gt_image)
            axes[row][2].set_xticks([])
            axes[row][2].set_yticks([])
            if row == 0:
                axes[row][2].set_title("Ground truth", fontsize=10)
        return fig


class Plot:
    """Plot input points, reconstruction and the full point cloud per row.

    One figure row is produced for every entry of ``num_point_samples_list``;
    the three columns show the sampled context points, the mesh returned by
    ``marching_cubes_func`` and a rendering of ``pc_data.vertices``.
    """

    def __init__(self, model: Model, minibatch_generator: MinibatchGenerator,
                 num_point_samples_list: List[int],
                 optimize_latent_func: Callable[[torch.Tensor], torch.Tensor],
                 marching_cubes_func: Callable[[torch.Tensor], None]):
        self.model = model
        self.minibatch_generator = minibatch_generator
        self.num_point_samples_list = num_point_samples_list
        self.optimize_latent = optimize_latent_func
        self.marching_cubes = marching_cubes_func

    def __call__(self, pc_data: UniformPointCloudData):
        """Return the comparison figure for ``pc_data``."""
        # Camera position in spherical coordinates, looking at the origin.
        theta = math.pi / 3
        phi = -math.pi / 4
        radius = 1
        eye = [
            radius * math.sin(theta) * math.cos(phi),
            radius * math.cos(theta),
            radius * math.sin(theta) * math.sin(phi),
        ]
        rotation_matrix, translation_vector = _look_at(eye=eye,
                                                       center=[0, 0, 0],
                                                       up=[0, 1, 0])
        translation_vector = translation_vector[None, :]
        rotation_matrix = np.linalg.inv(rotation_matrix)

        def to_camera(xyz: np.ndarray):
            # Apply the camera rotation and translation to row-vector points.
            return (rotation_matrix @ xyz.T).T + translation_vector

        def cmap_binary(points: np.ndarray):
            # Grey level derived from the first coordinate; the column is
            # copied so the caller's array is left untouched.
            x = points[:, 0].copy()
            scale = 1 / np.max(np.abs(x))
            x *= -scale
            shade = 0.3 * (x + 1) / 2
            return np.repeat(shade[:, None], 3, axis=1)

        def hide_ticks(ax):
            ax.set_xticks([])
            ax.set_yticks([])

        # Rendering of the complete point cloud, reused in every row.
        gt_points = pc_data.vertices
        gt_image = pcl.render_point_cloud(to_camera(gt_points),
                                          cmap_binary(gt_points),
                                          camera_mag=1,
                                          point_size=6)

        num_rows = len(self.num_point_samples_list)
        dpi = 100
        figsize_inch = np.array([600, 200 * num_rows]) / dpi

        fig, axes = plt.subplots(num_rows, 3, figsize=figsize_inch)
        if num_rows == 1:
            # A single row comes back as a 1-D array of axes.
            axes = [axes]

        for row, num_point_samples in enumerate(self.num_point_samples_list):
            print(f"row {row+1} of {num_rows}", flush=True)
            input_ax, recon_ax, gt_ax = axes[row]

            minibatch = self.minibatch_generator(
                [pc_data],
                min_num_context=num_point_samples,
                max_num_context=num_point_samples,
                min_num_target=0,
                max_num_target=0,
                tuple_target=False,
                random_state=np.random.RandomState(0))

            # Merge all context entries into a single (1, N, 3) batch.
            context_points = torch.stack(
                minibatch.context_points_list).view((-1, 3))
            context_normals = torch.stack(
                minibatch.context_normals_list).view((-1, 3))
            minibatch.context_points_list = [context_points[None, :, :]]
            minibatch.context_normals_list = [context_normals[None, :, :]]

            # Column 0: the sampled input points.
            input_points = context_points.detach().cpu().numpy().reshape(
                (-1, 3))
            image = pcl.render_point_cloud(to_camera(input_points),
                                           cmap_binary(input_points),
                                           camera_mag=1,
                                           point_size=6)
            input_ax.imshow(image)
            hide_ticks(input_ax)
            input_ax.set_ylabel(num_point_samples)
            if row == 0:
                input_ax.set_title("Input points", fontsize=10)

            # Column 1: encode, refine the latent code and extract a mesh.
            with torch.no_grad():
                h, _ = self.model.encoder(context_points[None, :, :])
            h = self.optimize_latent(minibatch, h)
            print("optimize_latent done.", flush=True)

            try:
                mc_vertices, mc_faces = self.marching_cubes(h)
                print("marching_cubes done.", flush=True)
                mesh = trimesh.Trimesh(vertices=to_camera(mc_vertices),
                                       faces=mc_faces)
                recon_ax.imshow(pcl.render_mesh(mesh, camera_mag=1))
            except ValueError:
                # Leave the reconstruction panel empty on failure.
                pass
            hide_ticks(recon_ax)
            if row == 0:
                recon_ax.set_title("Reconstruction", fontsize=10)

            # Column 2: the full point cloud.
            gt_ax.imshow(gt_image)
            hide_ticks(gt_ax)
            if row == 0:
                gt_ax.set_title("Ground truth", fontsize=10)
        return fig


# Command group under which the "plot_data" and "plot_dataset" commands in
# this file are registered; it is invoked when the script is run directly.
@click.group()
def client():
    pass


@client.command(name="plot_data")
@click.argument("--checkpoint-directory", type=str, required=True)
@click.argument("--checkpoint-epoch", type=int, default=None)
@click.argument("--output-directory", type=str, required=True)
@click.argument("--npz-data", type=str, required=True)
@click.argument("--obj-data", type=str, required=True)
@click.argument("--grid-size", type=int, default=256)
@click.argument("--latent-optimization-iterations", type=int, default=800)
@click.argument("--latent-optimization-initial-lr", type=float, default=0.001)
@click.argument("--seed", type=int, default=0)
def plot_data(args):
    # Reconstruct a single point-cloud / mesh pair from a trained checkpoint
    # and save the comparison figure as a PNG in the output directory.
    device = torch.device("cuda", 0)
    output_directory = Path(args.output_directory)
    mkdir(output_directory)

    # Hyperparameters are read from "args.json" next to the weights; without
    # --checkpoint-epoch the file "model.pt" is used.
    checkpoint_directory = Path(args.checkpoint_directory)
    args_path = checkpoint_directory / "args.json"
    if args.checkpoint_epoch is None:
        model_path = checkpoint_directory / "model.pt"
    else:
        model_path = checkpoint_directory / f"model.{args.checkpoint_epoch}.pt"
    assert args_path.is_file()
    assert model_path.is_file()

    model_hyperparams = ModelHyperparameters.load_json(args_path)
    training_hyperparams = TrainingHyperparameters.load_json(args_path)

    model = setup_model(model_hyperparams)
    load_model(model_path, model)
    model.to(device)
    model.eval()

    # NOTE(review): PointCloudAndMeshPairDataset is not imported in the
    # visible part of this file -- confirm where it is defined and import it.
    dataset = PointCloudAndMeshPairDataset([(args.npz_data, args.obj_data)])
    minibatch_generator = MinibatchGenerator(device=device)

    def kld_weight_func():
        return training_hyperparams.loss_kld_final_weight

    loss_function = LossFunction(
        tau=training_hyperparams.loss_tau,
        lam=training_hyperparams.loss_lambda,
        kld_weight_func=kld_weight_func,
        num_eikonal_samples=training_hyperparams.num_eikonal_samples,
        eikonal_term_stddev=training_hyperparams.eikonal_term_stddev)

    # NOTE(review): plot_dataset constructs LatentOptimization with
    # ``params=...``; confirm these keyword arguments are still accepted.
    optimize_latent = LatentOptimization(
        model=model,
        loss_function=loss_function,
        lr=args.latent_optimization_initial_lr,
        decrease_lr_every=args.latent_optimization_iterations // 2,
        max_iterations=args.latent_optimization_iterations)

    grid_max_value = 1
    grid_min_value = -1
    marching_cubes = MarchingCubes(model=model,
                                   grid_size=args.grid_size,
                                   grid_max_value=grid_max_value,
                                   grid_min_value=grid_min_value)

    num_point_samples_list = [1000]
    # _Plot is the plotter whose __call__ accepts (pc_data, mesh_data);
    # Plot.__call__ only takes pc_data and would raise TypeError below.
    plot = _Plot(model=model,
                 minibatch_generator=minibatch_generator,
                 num_point_samples_list=num_point_samples_list,
                 optimize_latent_func=optimize_latent,
                 marching_cubes_func=marching_cubes)

    for data_tuple in dataset:
        pc_data: UniformPointCloudData = data_tuple[0]
        mesh_data: MeshData = data_tuple[1]

        fig = plot(pc_data, mesh_data)

        # Category and model ids are taken from the third-last and
        # second-last components of the point-cloud path.
        parts = str(pc_data.path).split("/")
        category_id = parts[-3]
        model_id = parts[-2]

        figure_path = output_directory / f"{category_id}_{model_id}.png"
        plt.tight_layout()
        plt.subplots_adjust(top=0.92)
        plt.savefig(figure_path, dpi=300, bbox_inches="tight", pad_inches=0.05)
        plt.close(fig)
        print(figure_path, flush=True)


@client.command(name="plot_dataset")
@click.argument("--checkpoint-directory", type=str, required=True)
@click.argument("--checkpoint-epoch", type=int, default=None)
@click.argument("--output-directory", type=str, required=True)
@click.argument("--dataset-directory", type=str, required=True)
@click.argument("--test-split-path", type=str, required=True)
@click.argument("--grid-size", type=int, default=256)
@click.argument("--seed", type=int, default=0)
@click.hyperparameter_class(LatentOptimizationHyperparameters)
def plot_dataset(
        args, latent_optimization_params: LatentOptimizationHyperparameters):
    # Reconstruct every shape listed in the split file from a trained
    # checkpoint and save one comparison figure (PNG) per shape.
    device = torch.device("cuda", 0)
    dataset_directory = Path(args.dataset_directory)
    output_directory = Path(args.output_directory)
    mkdir(output_directory)

    # Locate the checkpoint: "args.json" holds the hyperparameters, and the
    # weights file depends on whether an epoch was requested.
    checkpoint_directory = Path(args.checkpoint_directory)
    args_path = checkpoint_directory / "args.json"
    if args.checkpoint_epoch is None:
        weights_name = "model.pt"
    else:
        weights_name = f"model.{args.checkpoint_epoch}.pt"
    model_path = checkpoint_directory / weights_name
    assert args_path.is_file()
    assert model_path.is_file()

    model_hyperparams = ModelHyperparameters.load_json(args_path)
    training_hyperparams = TrainingHyperparameters.load_json(args_path)

    model = setup_model(model_hyperparams)
    load_model(model_path, model)
    model.to(device)
    model.eval()

    test_split_path = Path(args.test_split_path)
    assert test_split_path.is_file()

    with open(test_split_path) as f:
        split = json.load(f)

    # The split maps a category to its model ids; each model is stored as
    # <dataset>/<category>/depth/<model_id>.npz.
    npz_paths = [
        dataset_directory / category / "depth" / f"{model_id}.npz"
        for category in split
        for model_id in split[category]
    ]

    print(len(npz_paths))

    dataset = UniformPointCloudDataset(npz_paths)
    minibatch_generator = MinibatchGenerator(device=device)

    def kld_weight_func():
        return training_hyperparams.loss_kld_final_weight

    loss_function = LossFunction(
        tau=training_hyperparams.loss_tau,
        lam=training_hyperparams.loss_lambda,
        kld_weight_func=kld_weight_func,
        num_eikonal_samples=training_hyperparams.num_eikonal_samples,
        eikonal_term_stddev=training_hyperparams.eikonal_term_stddev)

    optimize_latent = LatentOptimization(model=model,
                                         loss_function=loss_function,
                                         params=latent_optimization_params)

    marching_cubes = MarchingCubes(model=model,
                                   grid_size=args.grid_size,
                                   grid_max_value=1,
                                   grid_min_value=-1)

    plot = Plot(model=model,
                minibatch_generator=minibatch_generator,
                num_point_samples_list=[50, 1000, 30000],
                optimize_latent_func=optimize_latent,
                marching_cubes_func=marching_cubes)

    for pc_data in dataset.shuffle():
        print(colorful.bold(str(pc_data.path)), flush=True)

        fig = plot(pc_data)

        # Category is the third-last path component; the model id is the
        # file name without ".npz".
        path_parts = pc_data.path.parts
        category_id = path_parts[-3]
        model_id = path_parts[-1].replace(".npz", "")

        figure_path = output_directory / f"{category_id}_{model_id}.png"
        plt.tight_layout()
        plt.subplots_adjust(top=0.92)
        plt.savefig(figure_path, dpi=300, bbox_inches="tight", pad_inches=0.05)
        plt.close(fig)
        print(figure_path, flush=True)


# Run the command group when this file is executed as a script.
if __name__ == "__main__":
    client()
