import json
import math
from pathlib import Path
from typing import Callable, List

import colorful
import matplotlib.pyplot as plt
import numpy as np
import torch
import trimesh
from scipy.spatial import cKDTree as KDTree

from meta_learning_sdf import click, load_model, mkdir
from meta_learning_sdf.datasets.functions import (read_mesh_data,
                                                  read_uniform_point_cloud_data
                                                  )
from meta_learning_sdf.datasets.shapenet.uniform_sparse_sampling import (
    MeshData, MinibatchGenerator, CombinedDataset, UniformPointCloudData)
from meta_learning_sdf.experiments.baseline import (
    LatentOptimization, LatentOptimizationHyperparameters, LossFunction,
    MarchingCubes, Model, ModelHyperparameters, TrainingHyperparameters,
    setup_model)
from meta_learning_sdf.point_cloud import render_point_cloud, render_mesh


def _normalize(vec: np.ndarray):
    """Return *vec* divided by the L2 norm of all its entries.

    There is no zero-length guard: a zero vector yields NaNs (NumPy emits a
    runtime warning).
    """
    magnitude = np.linalg.norm(vec)
    return np.divide(vec, magnitude)


def _look_at(eye: np.ndarray, center: np.ndarray, up: np.ndarray):
    """Build an orthonormal camera frame looking from *eye* towards *center*.

    Args:
        eye: camera position (array-like of length 3).
        center: point the camera looks at.
        up: approximate "up" hint used to derive the side axis.

    Returns:
        ``(rotation_matrix, translation_vector)``.
        ``rotation_matrix`` is a float32 3x3 whose columns are the camera's
        side (x), up (y) and backward (z) axes.
        ``translation_vector`` is ``[-x.eye, -y.eye, -z.eye]``.
    """
    eye_pos = np.asanyarray(eye)
    target = np.asanyarray(center)
    up_hint = np.asanyarray(up)

    backward = _normalize(eye_pos - target)
    raw_side = np.cross(up_hint, backward)
    # The camera up axis is derived from the *unnormalized* side vector and
    # only then normalized, so the rounding matches the classic formulation.
    cam_up = _normalize(np.cross(backward, raw_side))
    side = _normalize(raw_side)

    frame = (side, cam_up, backward)
    rotation_matrix = np.stack(frame, axis=1).astype(np.float32)
    translation_vector = np.array([(-axis) @ eye_pos for axis in frame])

    return rotation_matrix, translation_vector


def _sample_surface_points(faces: np.ndarray, vertices: np.ndarray,
                           num_point_samples: int):
    """Draw ``num_point_samples`` random points on a triangle mesh surface.

    The mesh is built from *vertices* / *faces* with ``trimesh.Trimesh`` and
    sampled via ``trimesh.sample.sample_surface``.  Only the sampled point
    coordinates are returned; the per-sample face indices are discarded.
    """
    surface_mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
    points, _ = trimesh.sample.sample_surface(surface_mesh, num_point_samples)
    return points


def _compute_chamfer_distance(gt_surface_points: np.ndarray,
                              pred_surface_points: np.ndarray):
    """Squared symmetric Chamfer distance between two point sets.

    Returns the mean squared nearest-neighbour distance from predicted to
    ground-truth points, plus the same quantity from ground-truth to
    predicted points.
    """
    def mean_squared_nn_distance(reference, queries):
        # distance from every query point to its closest reference point
        nn_distances, _ = KDTree(reference).query(queries)
        return np.mean(np.square(nn_distances))

    pred_to_gt = mean_squared_nn_distance(gt_surface_points,
                                          pred_surface_points)
    gt_to_pred = mean_squared_nn_distance(pred_surface_points,
                                          gt_surface_points)
    return pred_to_gt + gt_to_pred


def _compute_non_squared_chamfer_distance(gt_surface_points: np.ndarray,
                                          pred_surface_points: np.ndarray):
    """Per-point (non-squared) nearest-neighbour distances in both directions.

    Returns:
        ``(pred_to_gt, gt_to_pred)``.
        ``pred_to_gt[i]`` is the distance from predicted point ``i`` to its
        closest ground-truth point.
        ``gt_to_pred[j]`` is the distance from ground-truth point ``j`` to its
        closest predicted point.
    """
    pred_to_gt = KDTree(gt_surface_points).query(pred_surface_points)[0]
    gt_to_pred = KDTree(pred_surface_points).query(gt_surface_points)[0]
    return pred_to_gt, gt_to_pred


class Plot:
    """Render a grid comparing input point clouds, reconstructions and GT.

    One row is drawn per entry of ``num_point_samples_list``.  Each row shows:

    1. the context points fed to the encoder,
    2. the marching-cubes reconstruction after latent optimization (left
       blank if reconstruction raised ``ValueError``),
    3. the ground-truth mesh.
    """

    def __init__(self, model: Model, minibatch_generator: MinibatchGenerator,
                 num_point_samples_list: List[int],
                 optimize_latent_func: Callable[[torch.Tensor], torch.Tensor],
                 marching_cubes_func: Callable[[torch.Tensor], None]):
        self.model = model
        self.minibatch_generator = minibatch_generator
        self.num_point_samples_list = num_point_samples_list
        self.optimize_latent = optimize_latent_func
        self.marching_cubes = marching_cubes_func

    @staticmethod
    def _camera_transform():
        """Return ``(rotation_matrix, translation_vector)`` for the fixed view.

        The rotation is the inverse of the look-at frame.  The translation is
        shaped ``(1, 3)`` so it broadcasts over ``(N, 3)`` vertex arrays.
        """
        camera_theta = math.pi / 3
        camera_phi = -math.pi / 4
        camera_r = 1
        eye = [
            camera_r * math.sin(camera_theta) * math.cos(camera_phi),
            camera_r * math.cos(camera_theta),
            camera_r * math.sin(camera_theta) * math.sin(camera_phi),
        ]
        rotation_matrix, translation_vector = _look_at(eye=eye,
                                                       center=[0, 0, 0],
                                                       up=[0, 1, 0])
        return np.linalg.inv(rotation_matrix), translation_vector[None, :]

    @staticmethod
    def _cmap_binary(points: np.ndarray):
        """Gray-scale RGB colors shaded by the x coordinate of *points*.

        Does NOT modify *points*.  (The previous in-place ``x *= ...`` wrote
        through a view into the caller's array.)  An all-zero x column maps
        to mid intensity instead of NaN.
        """
        x = points[:, 0]
        max_abs = np.max(np.abs(x))
        if max_abs > 0:
            scale = 1 / max_abs
            x = x * -scale  # new array: normalized to [-1, 1] and mirrored
        else:
            x = np.zeros_like(x)
        intensity = 0.3 * (x + 1) / 2
        return np.repeat(intensity[:, None], 3, axis=1)

    def __call__(self, pc_data: UniformPointCloudData, mesh_data: MeshData):
        """Build the comparison figure for one shape and return it."""
        rotation_matrix, translation_vector = self._camera_transform()

        def to_camera(vertices: np.ndarray) -> np.ndarray:
            # apply the fixed camera transform to an (N, 3) vertex array
            return (rotation_matrix @ vertices.T).T + translation_vector

        num_rows = len(self.num_point_samples_list)
        dpi = 100
        figsize_inch = np.array([600, 200 * num_rows]) / dpi

        # The ground truth is the same for every row, so it is rendered once.
        gt_vertices = (mesh_data.vertices + pc_data.offset) * pc_data.scale
        gt_mesh = trimesh.Trimesh(vertices=to_camera(gt_vertices),
                                  faces=mesh_data.vertex_indices)
        gt_image = render_mesh(gt_mesh, camera_mag=1)

        # squeeze=False always yields a 2-D array of axes, even for one row.
        fig, axes = plt.subplots(num_rows,
                                 3,
                                 figsize=figsize_inch,
                                 squeeze=False)
        for ax in axes.flat:
            ax.set_xticks([])
            ax.set_yticks([])
        if num_rows > 0:
            titles = ("Input points", "Reconstruction", "Ground truth")
            for col, title in enumerate(titles):
                axes[0, col].set_title(title, fontsize=10)

        for row, num_point_samples in enumerate(self.num_point_samples_list):
            print(f"row {row+1} of {num_rows}", flush=True)
            # Same seed for every row, so each row's sampling is reproducible.
            minibatch = self.minibatch_generator(
                [pc_data],
                min_num_context=num_point_samples,
                max_num_context=num_point_samples,
                min_num_target=30000,
                max_num_target=30000,
                random_state=np.random.RandomState(0))

            # Reconstruction must rely on the context points only.
            minibatch.target_points = None
            minibatch.target_normals = None

            # input point cloud
            input_points = minibatch.context_points_list[0][0].detach().cpu(
            ).numpy().reshape((-1, 3))
            image = render_point_cloud(to_camera(input_points),
                                       self._cmap_binary(input_points),
                                       camera_mag=1,
                                       point_size=6)
            axes[row, 0].imshow(image)
            axes[row, 0].set_ylabel(num_point_samples)

            # make prediction
            with torch.no_grad():
                _, h_dist = self.model.encoder(
                    minibatch.context_points_list[0])
                h = h_dist.mean
            h = self.optimize_latent(minibatch, h)
            print("optimize_latent done.", flush=True)

            try:
                mc_vertices, mc_faces = self.marching_cubes(h)
                print("marching_cubes done.", flush=True)
                mesh = trimesh.Trimesh(vertices=to_camera(mc_vertices),
                                       faces=mc_faces)
                axes[row, 1].imshow(render_mesh(mesh, camera_mag=1))
            except ValueError as error:
                # Best effort: leave the cell blank, but report why instead of
                # silently swallowing the failure.
                print(f"reconstruction failed for {num_point_samples} "
                      f"points: {error}",
                      flush=True)

            axes[row, 2].imshow(gt_image)
        return fig


class _Plot:
    """Debugging helper that inspects reconstruction error for one shape.

    It renders no images.  For the first entry of ``num_point_samples_list``
    whose marching cubes succeeds, it reconstructs the shape and prints
    histograms of:

    * distances from points sampled on the reconstructed mesh to their
      nearest neighbour among points sampled on the ground-truth mesh,
    * the same distances in the opposite direction,
    * ``|SDF|`` predicted by the decoder at the minibatch target points.

    It then shows the three as an overlaid matplotlib histogram.
    """

    def __init__(self, model: Model, minibatch_generator: MinibatchGenerator,
                 num_point_samples_list: List[int],
                 optimize_latent_func: Callable[[torch.Tensor], torch.Tensor],
                 marching_cubes_func: Callable[[torch.Tensor], None]):
        self.model = model
        self.minibatch_generator = minibatch_generator
        self.num_point_samples_list = num_point_samples_list
        self.optimize_latent = optimize_latent_func
        self.marching_cubes = marching_cubes_func

    def __call__(self, pc_data: UniformPointCloudData, mesh_data: MeshData):
        """Analyze one shape and return the histogram figure.

        Previously this referenced an undefined ``fig`` and called ``exit()``,
        which terminated the process.  It now returns after the first
        successfully analyzed row.
        """
        num_samples = 30000  # target points and surface samples per mesh

        def print_histogram(values):
            counts, bins = np.histogram(values)
            print(counts)
            print(bins)
            print("| ----------------------- | ---- |")
            for count, start, end in zip(counts, bins[:-1], bins[1:]):
                print(f"| {start:.6f} ~ {end:.6f} | {count}    |")

        # Ground-truth mesh normalized with the point cloud's offset/scale so
        # it is comparable with the marching-cubes output.  No camera
        # transform is applied: it is rigid, so it does not change distances.
        gt_faces = mesh_data.vertex_indices
        gt_vertices = (mesh_data.vertices + pc_data.offset) * pc_data.scale

        fig = plt.figure()
        for row, num_point_samples in enumerate(self.num_point_samples_list):
            print(f"row {row+1} of {len(self.num_point_samples_list)}",
                  flush=True)
            minibatch = self.minibatch_generator(
                [pc_data],
                min_num_context=num_point_samples,
                max_num_context=num_point_samples,
                min_num_target=num_samples,
                max_num_target=num_samples,
                random_state=np.random.RandomState(0))

            # Keep the target points for the SDF check, but hide them from
            # latent optimization.
            gt_cd_points = minibatch.target_points
            # NOTE(review): kept from the original; presumably the decoder
            # needs gradients w.r.t. X — TODO confirm.
            gt_cd_points.requires_grad_(True)
            minibatch.target_points = None
            minibatch.target_normals = None

            # make prediction
            with torch.no_grad():
                _, h_dist = self.model.encoder(
                    minibatch.context_points_list[0])
                h = h_dist.mean
            h = self.optimize_latent(minibatch, h)
            print("optimize_latent done.", flush=True)

            try:
                mc_vertices, mc_faces = self.marching_cubes(h)
            except ValueError as error:
                # Previously this fell through to an UnboundLocalError on
                # mc_faces; try the next row instead.
                print(f"marching_cubes failed: {error}", flush=True)
                continue
            print("marching_cubes done.", flush=True)

            gt_surface_points = _sample_surface_points(gt_faces, gt_vertices,
                                                       num_samples)
            pred_surface_points = _sample_surface_points(
                mc_faces, mc_vertices, num_samples)
            pred_to_gt, gt_to_pred = _compute_non_squared_chamfer_distance(
                gt_surface_points, pred_surface_points)

            h_expanded = h[:, None, :].expand(
                (h.shape[0], num_samples, h.shape[1]))
            distance = self.model.decoder(X=gt_cd_points, h=h_expanded)
            # Flatten: plt.hist rejects per-dataset arrays with ndim > 1.
            abs_sdf = distance.abs().detach().cpu().numpy().reshape(-1)

            for values in (pred_to_gt, gt_to_pred, abs_sdf):
                print_histogram(values)

            symmetric_chamfer_distance = np.mean(pred_to_gt) + np.mean(
                gt_to_pred)
            print(f"symmetric chamfer distance: {symmetric_chamfer_distance}")
            print(f"mean |SDF| at target points: {float(abs_sdf.mean())}")

            plt.hist(
                [pred_to_gt, gt_to_pred, abs_sdf],
                np.linspace(0, 0.25, 200),
                label=[
                    "Σ_{j∈SDF} min_{i∈GT} d(i,j)",
                    "Σ_{j∈GT} min_{i∈SDF} d(i,j)", "mean_i|SDF(x_i)|"
                ])
            plt.legend(loc='upper right')
            plt.show()
            # Only one row is analyzed (this replaces a process-killing exit()).
            break

        return fig


@click.group()
def client():
    # Command group only; sub-commands attach themselves via
    # ``@client.command(...)`` and this callback has nothing to do.
    return None


@client.command(name="plot_data")
@click.argument("--checkpoint-directory", type=str, required=True)
@click.argument("--checkpoint-epoch", type=int, default=None)
@click.argument("--output-directory", type=str, required=True)
@click.argument("--pc-data-path", type=str, required=True)
@click.argument("--mesh-data-path", type=str, required=True)
@click.argument("--grid-size", type=int, default=256)
@click.argument("--seed", type=int, default=0)
@click.hyperparameter_class(LatentOptimizationHyperparameters)
def plot_data(args,
              latent_optimization_params: LatentOptimizationHyperparameters):
    """Render the comparison figure for a single point-cloud / mesh pair.

    Loads ``model.pt`` (or ``model.<epoch>.pt``) and ``args.json`` from the
    checkpoint directory.  It then saves ``<category_id>_<model_id>.png`` into
    the output directory, where the ids are the 3rd- and 2nd-to-last
    components of the point-cloud path.

    Raises:
        FileNotFoundError: if ``args.json`` or the model file is missing.
    """
    device = torch.device("cuda", 0)
    # --seed was previously accepted but never applied.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    output_directory = Path(args.output_directory)
    mkdir(output_directory)

    checkpoint_directory = Path(args.checkpoint_directory)
    args_path = checkpoint_directory / "args.json"
    if args.checkpoint_epoch is None:
        model_path = checkpoint_directory / "model.pt"
    else:
        model_path = checkpoint_directory / f"model.{args.checkpoint_epoch}.pt"
    # Explicit checks: ``assert`` is stripped under ``python -O``.
    for required_path in (args_path, model_path):
        if not required_path.is_file():
            raise FileNotFoundError(required_path)

    model_hyperparams = ModelHyperparameters.load_json(args_path)
    training_hyperparams = TrainingHyperparameters.load_json(args_path)

    model = setup_model(model_hyperparams)
    load_model(model_path, model)
    model.to(device)
    model.eval()

    dataset = CombinedDataset([(args.pc_data_path, args.mesh_data_path)],
                              [read_uniform_point_cloud_data, read_mesh_data])
    minibatch_generator = MinibatchGenerator(device=device)

    def kld_weight_func():
        return training_hyperparams.loss_kld_final_weight

    loss_function = LossFunction(
        tau=training_hyperparams.loss_tau,
        lam=training_hyperparams.loss_lambda,
        kld_weight_func=kld_weight_func,
        num_eikonal_samples=training_hyperparams.num_eikonal_samples,
        eikonal_term_stddev=training_hyperparams.eikonal_term_stddev)

    optimize_latent = LatentOptimization(model=model,
                                         loss_function=loss_function,
                                         params=latent_optimization_params)

    grid_max_value = 1
    grid_min_value = -1
    marching_cubes = MarchingCubes(model=model,
                                   grid_size=args.grid_size,
                                   grid_max_value=grid_max_value,
                                   grid_min_value=grid_min_value)

    num_point_samples_list = [50, 100, 300, 1000]
    plot = Plot(model=model,
                minibatch_generator=minibatch_generator,
                num_point_samples_list=num_point_samples_list,
                optimize_latent_func=optimize_latent,
                marching_cubes_func=marching_cubes)

    for data_tuple in dataset:
        pc_data: UniformPointCloudData = data_tuple[0]
        mesh_data: MeshData = data_tuple[1]

        fig = plot(pc_data, mesh_data)

        # Path.parts is separator-agnostic, unlike str.split("/").
        path_parts = Path(pc_data.path).parts
        category_id, model_id = path_parts[-3], path_parts[-2]

        figure_path = output_directory / f"{category_id}_{model_id}.png"
        # Operate on the returned figure explicitly, not pyplot's "current" one.
        fig.tight_layout()
        fig.subplots_adjust(top=0.92)
        fig.savefig(figure_path, dpi=300, bbox_inches="tight", pad_inches=0.05)
        plt.close(fig)
        print(figure_path, flush=True)


@client.command(name="plot_dataset")
@click.argument("--checkpoint-directory", type=str, required=True)
@click.argument("--checkpoint-epoch", type=int, default=None)
@click.argument("--output-directory", type=str, required=True)
@click.argument("--pc-dataset-directory", type=str, required=True)
@click.argument("--mesh-dataset-directory", type=str, required=True)
@click.argument("--test-split-path", type=str, required=True)
@click.argument("--grid-size", type=int, default=256)
@click.argument("--seed", type=int, default=0)
@click.hyperparameter_class(LatentOptimizationHyperparameters)
def plot_dataset(
        args, latent_optimization_params: LatentOptimizationHyperparameters):
    """Render comparison figures for every shape listed in a test split.

    The split JSON maps ``category_id`` to a list of ``model_id`` values.
    Shapes whose ``point_cloud.npz`` or
    ``models/model_normalized.obj`` file is missing are skipped.  One
    ``<category_id>_<model_id>.png`` is saved per remaining shape.

    Raises:
        FileNotFoundError: if ``args.json``, the model file or the split file
            is missing.
    """
    device = torch.device("cuda", 0)
    # --seed was previously accepted but never applied.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    pc_dataset_directory = Path(args.pc_dataset_directory)
    mesh_dataset_directory = Path(args.mesh_dataset_directory)
    output_directory = Path(args.output_directory)
    mkdir(output_directory)

    checkpoint_directory = Path(args.checkpoint_directory)
    args_path = checkpoint_directory / "args.json"
    if args.checkpoint_epoch is None:
        model_path = checkpoint_directory / "model.pt"
    else:
        model_path = checkpoint_directory / f"model.{args.checkpoint_epoch}.pt"
    test_split_path = Path(args.test_split_path)
    # Explicit checks: ``assert`` is stripped under ``python -O``.
    for required_path in (args_path, model_path, test_split_path):
        if not required_path.is_file():
            raise FileNotFoundError(required_path)

    model_hyperparams = ModelHyperparameters.load_json(args_path)
    training_hyperparams = TrainingHyperparameters.load_json(args_path)

    model = setup_model(model_hyperparams)
    load_model(model_path, model)
    model.to(device)
    model.eval()

    with open(test_split_path, encoding="utf-8") as f:
        split = json.load(f)

    combined_data_path_list = []
    for category_id, model_ids in split.items():
        for model_id in sorted(model_ids):
            pc_data_path = (pc_dataset_directory / category_id / model_id /
                            "point_cloud.npz")
            mesh_data_path = (mesh_dataset_directory / category_id /
                              model_id / "models" / "model_normalized.obj")
            if pc_data_path.exists() and mesh_data_path.exists():
                combined_data_path_list.append((pc_data_path, mesh_data_path))

    print(len(combined_data_path_list))

    dataset = CombinedDataset(combined_data_path_list,
                              [read_uniform_point_cloud_data, read_mesh_data])
    minibatch_generator = MinibatchGenerator(device=device)

    def kld_weight_func():
        return training_hyperparams.loss_kld_final_weight

    loss_function = LossFunction(
        tau=training_hyperparams.loss_tau,
        lam=training_hyperparams.loss_lambda,
        kld_weight_func=kld_weight_func,
        num_eikonal_samples=training_hyperparams.num_eikonal_samples,
        eikonal_term_stddev=training_hyperparams.eikonal_term_stddev)

    optimize_latent = LatentOptimization(model=model,
                                         loss_function=loss_function,
                                         params=latent_optimization_params)

    grid_max_value = 1
    grid_min_value = -1
    marching_cubes = MarchingCubes(model=model,
                                   grid_size=args.grid_size,
                                   grid_max_value=grid_max_value,
                                   grid_min_value=grid_min_value)

    num_point_samples_list = [50, 100, 300, 1000]
    plot = Plot(model=model,
                minibatch_generator=minibatch_generator,
                num_point_samples_list=num_point_samples_list,
                optimize_latent_func=optimize_latent,
                marching_cubes_func=marching_cubes)

    for data_tuple in dataset:
        pc_data: UniformPointCloudData = data_tuple[0]
        mesh_data: MeshData = data_tuple[1]

        fig = plot(pc_data, mesh_data)

        # Path.parts is separator-agnostic, unlike str.split("/").
        path_parts = Path(pc_data.path).parts
        category_id, model_id = path_parts[-3], path_parts[-2]

        figure_path = output_directory / f"{category_id}_{model_id}.png"
        # Operate on the returned figure explicitly, not pyplot's "current" one.
        fig.tight_layout()
        fig.subplots_adjust(top=0.92)
        fig.savefig(figure_path, dpi=300, bbox_inches="tight", pad_inches=0.05)
        plt.close(fig)
        print(figure_path, flush=True)


if __name__ == "__main__":
    # Script entry point: dispatch to the ``plot_data`` / ``plot_dataset``
    # sub-commands registered on the ``client`` group.
    client()
