import json
import math
import shutil
from collections import defaultdict
from pathlib import Path

import colorful
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import trimesh
from scipy.spatial import cKDTree as KDTree
from tabulate import tabulate

from implicit_geometric_regularization import click, load_model, mkdir
from implicit_geometric_regularization.datasets.functions import (
    read_mesh_data, read_uniform_point_cloud_data)
from implicit_geometric_regularization.datasets.shapenet.uniform_sparse_sampling import (
    CombinedDataset, MeshData, MinibatchGenerator, UniformPointCloudData,
    UniformPointCloudDataset)
from implicit_geometric_regularization.experiments.learning_shape_space import (
    AdaptiveLatentOptimization, DecoderHyperparameters, LatentOptimization,
    LatentOptimizationHyperparameters, LossFunction, MarchingCubes, Model,
    ModelHyperparameters, TrainingHyperparameters, setup_model)


def _normalize(vec: np.ndarray):
    return vec / np.linalg.norm(vec)


def _look_at(eye: np.ndarray, center: np.ndarray, up: np.ndarray):
    """Build a look-at camera frame.

    Args:
        eye: camera position.
        center: point the camera looks at.
        up: approximate up direction. It must not be parallel to
            ``eye - center``, otherwise the cross product vanishes and
            normalization divides by zero.

    Returns:
        ``(rotation_matrix, translation_vector)``. ``rotation_matrix`` is a
        float32 3x3 matrix whose columns are the camera's x, y and z axes;
        z points from ``center`` toward ``eye``. ``translation_vector`` holds
        ``-axis . eye`` for each of those axes.
    """
    eye, center, up = (np.asanyarray(v) for v in (eye, center, up))

    z_axis = _normalize(eye - center)
    x_axis = np.cross(up, z_axis)
    # z is unit length and x is orthogonal to it, so |y| == |x|. Both are
    # normalized afterwards.
    y_axis = np.cross(z_axis, x_axis)
    x_axis, y_axis = _normalize(x_axis), _normalize(y_axis)

    rotation_matrix = np.column_stack(
        (x_axis, y_axis, z_axis)).astype(np.float32)
    translation_vector = np.array(
        [-axis @ eye for axis in (x_axis, y_axis, z_axis)])

    return rotation_matrix, translation_vector


def _sample_surface_points(faces: np.ndarray, vertices: np.ndarray,
                           num_point_samples: int):
    """Sample ``num_point_samples`` points on the surface of a triangle mesh.

    The mesh is built with ``trimesh.Trimesh(vertices=..., faces=...)`` and
    sampled with ``trimesh.sample.sample_surface``. Only the sampled point
    coordinates are returned; the per-sample face indices are discarded.
    """
    triangle_mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
    points, _ = trimesh.sample.sample_surface(triangle_mesh,
                                              num_point_samples)
    return points


def _compute_chamfer_distance(gt_surface_points: np.ndarray,
                              pred_surface_points: np.ndarray):
    gt_points_kd_tree = KDTree(gt_surface_points)
    distances, locations = gt_points_kd_tree.query(pred_surface_points)
    cd_term1 = np.mean(np.square(distances))

    pred_points_kd_tree = KDTree(pred_surface_points)
    distances, locations = pred_points_kd_tree.query(gt_surface_points)
    cd_term2 = np.mean(np.square(distances))

    return cd_term1 + cd_term2


def _summarize(result_directory: str):
    result_directory = Path(result_directory)
    metrics_directory = result_directory / "metrics"
    args_path = result_directory / "args.json"
    assert args_path.is_file()
    with open(args_path) as f:
        args = json.load(f)
    table = defaultdict(list)
    metrics_path_list = list(metrics_directory.glob("*.json"))
    for metrics_path in metrics_path_list:
        with open(metrics_path) as f:
            metrics = json.load(f)
        for key, value in metrics.items():
            table[key].append(value)

    return table, args, len(metrics_path_list)


class ChamferDistance:
    """Chamfer distance between a reconstruction and a ground-truth mesh.

    A latent code is fitted to points drawn from ``pc_data``. The decoder's
    surface is then extracted with marching cubes on a ``[-1, 1]`` grid, and
    surface samples of the reconstruction are compared with surface samples
    of the ground-truth mesh in ``mesh_data``.
    """

    def __init__(self, minibatch_generator: MinibatchGenerator,
                 optimize_latent_func: LatentOptimization, model: Model,
                 grid_size: int, latent_optimization_num_samples: int,
                 chamfer_distance_num_samples: int):
        self.minibatch_generator = minibatch_generator
        self.optimize_latent = optimize_latent_func
        self.model = model
        self.grid_size = grid_size
        self.latent_optimization_num_samples = latent_optimization_num_samples
        self.chamfer_distance_num_samples = chamfer_distance_num_samples

    def __call__(self, pc_data: UniformPointCloudData, mesh_data: MeshData):
        """Evaluate one shape.

        Args:
            pc_data: input point cloud used to optimize the latent code. Its
                ``scale``/``offset`` are used to map reconstructed vertices
                back to the mesh's coordinate frame.
            mesh_data: ground-truth mesh (``vertices``, ``vertex_indices``).

        Returns:
            The symmetric squared Chamfer distance, or ``-1`` when marching
            cubes raises ``ValueError`` (presumably no surface was found
            inside the grid; TODO confirm against ``MarchingCubes``).
        """
        gt_faces = mesh_data.vertex_indices
        gt_vertices = mesh_data.vertices

        minibatch = self.minibatch_generator(
            [pc_data], num_point_samples=self.latent_optimization_num_samples)
        # The latent code is optimized from point positions only.
        minibatch.normals = None
        # ``Tensor.get_device()`` returns -1 for CPU tensors, which ``.to()``
        # rejects. The ``device`` attribute works for both CPU and CUDA.
        device = minibatch.points.device

        initial_z = torch.normal(mean=0,
                                 std=0.01,
                                 size=(1, self.model.z_map.shape[1]),
                                 dtype=torch.float32).to(device)
        z = self.optimize_latent(minibatch, initial_z)

        grid_max_value = 1
        grid_min_value = -1
        marching_cubes = MarchingCubes(model=self.model,
                                       grid_size=self.grid_size,
                                       grid_max_value=grid_max_value,
                                       grid_min_value=grid_min_value)
        # Only the surface extraction is guarded. A ValueError from the
        # rescaling below is a real bug and must not be reported as "-1".
        try:
            mc_vertices, mc_faces = marching_cubes(z)
        except ValueError:
            return -1
        # Presumably inverts the point cloud's normalization,
        # i.e. ``(x + offset) * scale``. TODO confirm in the dataset code.
        mc_vertices = mc_vertices / pc_data.scale - pc_data.offset

        gt_surface_points = _sample_surface_points(
            gt_faces, gt_vertices, self.chamfer_distance_num_samples)
        pred_surface_points = _sample_surface_points(
            mc_faces, mc_vertices, self.chamfer_distance_num_samples)

        return _compute_chamfer_distance(gt_surface_points,
                                         pred_surface_points)


class SurfaceChamferDistance:
    """Chamfer distance between a reconstruction and its own input samples.

    The reference point set is drawn from the same ``pc_data`` that is used
    to fit the latent code. No ground-truth mesh is needed.
    """

    def __init__(self, minibatch_generator: MinibatchGenerator,
                 optimize_latent_func: LatentOptimization, model: Model,
                 grid_size: int, latent_optimization_num_samples: int,
                 chamfer_distance_num_samples: int):
        # Collaborators.
        self.model = model
        self.minibatch_generator = minibatch_generator
        self.optimize_latent = optimize_latent_func
        # Sampling / extraction settings.
        self.grid_size = grid_size
        self.chamfer_distance_num_samples = chamfer_distance_num_samples
        self.latent_optimization_num_samples = latent_optimization_num_samples

    def __call__(self, pc_data: UniformPointCloudData):
        """Evaluate one point cloud.

        Returns:
            The symmetric squared Chamfer distance, or ``None`` when marching
            cubes raises ``ValueError``.
        """
        # Reference points: a draw of ``chamfer_distance_num_samples`` points.
        reference_batch = self.minibatch_generator(
            [pc_data], num_point_samples=self.chamfer_distance_num_samples)
        reference_points = reference_batch.points.detach().cpu().numpy()[0]

        # Observation points used to fit the latent code (normals dropped).
        observation_batch = self.minibatch_generator(
            [pc_data], num_point_samples=self.latent_optimization_num_samples)
        observation_batch.normals = None
        target_device = self.model.get_device()

        latent_dim = self.model.z_map.shape[1]
        z_init = torch.normal(mean=0,
                              std=0.01,
                              size=(1, latent_dim),
                              dtype=torch.float32).to(target_device)
        z = self.optimize_latent(observation_batch, z_init)

        extractor = MarchingCubes(model=self.model,
                                  grid_size=self.grid_size,
                                  grid_max_value=1,
                                  grid_min_value=-1)
        try:
            vertices, faces = extractor(z)
        except ValueError:
            return None

        # Bring both point sets into the same frame before comparing.
        scale, offset = pc_data.scale, pc_data.offset
        reference_points = reference_points / scale - offset
        vertices = vertices / scale - offset

        predicted_points = _sample_surface_points(
            faces, vertices, self.chamfer_distance_num_samples)
        return _compute_chamfer_distance(reference_points, predicted_points)


@click.group()
def client():
    # Root CLI group. The commands below attach to it via
    # ``@client.command``. A comment is used instead of a docstring so the
    # CLI help text is left unchanged.
    pass


def _is_valid_metric_value(value) -> bool:
    """Return True if ``value`` is a usable metric entry.

    ``ChamferDistance`` records failures as ``-1``, and
    ``SurfaceChamferDistance`` records them as ``None`` (JSON ``null``).
    Both are skipped. ``None`` is checked first because comparing it with a
    number raises ``TypeError``.
    """
    return value is not None and value >= 0


@client.command()
@click.argument("--result-directory", type=str, required=True)
def summarize(args):
    # Summarize the per-shape metric files written by the
    # ``chamfer_distance_*`` commands. Prints the copied training
    # hyperparameters, the best/worst shapes per metric, a histogram, and a
    # mean (± std) table. Failed evaluations (-1 or null) are ignored
    # everywhere.
    metric_list = ["chamfer_distance"]
    settings_list = [
        "latent_optimization_num_samples", "latent_optimization_iterations"
    ]
    # ``run_args`` is the parsed ``args.json``. It is kept distinct from the
    # CLI ``args`` to avoid shadowing it.
    result, run_args, num_data = _summarize(args.result_directory)
    print(colorful.bold("Hyperparameters:"))
    print(run_args)

    # Ascending (value, model_id) pairs per metric, valid entries only.
    model_ids = result["model_id"]
    rankings = {}
    for metric_name in metric_list:
        valid_pairs = [
            (value, model_id)
            for value, model_id in zip(result[metric_name], model_ids)
            if _is_valid_metric_value(value)
        ]
        rankings[metric_name] = sorted(valid_pairs, key=lambda item: item[0])

    # find top-k (smallest values are best)
    top_k = 10
    for metric_name in metric_list:
        print(colorful.bold(f"Top {top_k} {metric_name}:"))
        for value, model_id in rankings[metric_name][:top_k]:
            print(value, model_id)

    # find worst-k (largest values)
    worst_k = 10
    for metric_name in metric_list:
        print(colorful.bold(f"Worst {worst_k} {metric_name}:"))
        for value, model_id in rankings[metric_name][::-1][:worst_k]:
            print(value, model_id)

    data = {
        "learning_rate": run_args.get("learning_rate"),
        "num_data": num_data,
    }
    for metric_name in metric_list:
        values = np.array([
            value for value in result[metric_name]
            if _is_valid_metric_value(value)
        ],
                          dtype=np.float64)
        print(colorful.bold("Histogram:"))
        if values.size > 0:
            print(np.histogram(values))
            mean = np.mean(values)
            std = np.std(values)
        else:
            # Avoid numpy's empty-mean warning; report NaN explicitly.
            print("no valid values")
            mean = std = float("nan")
        data[f"{metric_name}_mean"] = mean
        data[f"{metric_name}_std"] = std
    for key in settings_list:
        # Settings are presumably identical across files of one run; the
        # first entry is reported (None when no metrics files exist).
        values = result[key]
        data[key] = values[0] if values else None

    tabulate_row = [data["num_data"]]
    for key in settings_list:
        tabulate_row.append(f"{data[key]}")
    for metric_name in metric_list:
        mean = data[f"{metric_name}_mean"]
        std = data[f"{metric_name}_std"]
        tabulate_row.append(f"{mean:.06f} (±{std:.06f})")

    print(colorful.bold("Result:"))
    print(
        tabulate([tabulate_row],
                 headers=["# of data"] + settings_list + metric_list,
                 tablefmt="github"))


@client.command(name="chamfer_distance_mesh_data")
@click.argument("--checkpoint-directory", type=str, required=True)
@click.argument("--checkpoint-epoch", type=int, default=None)
@click.argument("--output-directory", type=str, required=True)
@click.argument("--pc-data-path", type=str, required=True)
@click.argument("--mesh-data-path", type=str, required=True)
@click.argument("--grid-size", type=int, default=512)
@click.argument("--chamfer-distance-num-samples", type=int, default=30000)
@click.hyperparameter_class(LatentOptimizationHyperparameters)
def chamfer_distance_mesh_data(
        args, latent_optimization_params: LatentOptimizationHyperparameters):
    # Evaluate a trained model against ground-truth meshes. For each
    # (point cloud, mesh) pair:
    #   1. fit a latent code to the point cloud,
    #   2. reconstruct with marching cubes,
    #   3. write one metrics JSON to ``<output-directory>/metrics``.
    # The checkpoint's ``args.json`` is copied to the output directory so
    # ``summarize`` can report it.
    device = torch.device("cuda", 0)
    output_directory = Path(args.output_directory)
    metrics_directory = output_directory / "metrics"
    mkdir(output_directory)
    mkdir(metrics_directory)

    checkpoint_directory = Path(args.checkpoint_directory)
    args_path = checkpoint_directory / "args.json"
    if args.checkpoint_epoch is None:
        model_path = checkpoint_directory / "model.pt"
    else:
        model_path = checkpoint_directory / f"model.{args.checkpoint_epoch}.pt"
    # Explicit checks instead of ``assert``, which ``python -O`` strips.
    for required_path in (args_path, model_path):
        if not required_path.is_file():
            raise FileNotFoundError(
                f"required file not found: {required_path}")

    shutil.copyfile(args_path, output_directory / "args.json")

    model_hyperparams = ModelHyperparameters.load_json(args_path)
    decoder_hyperparams = DecoderHyperparameters.load_json(args_path)
    training_hyperparams = TrainingHyperparameters.load_json(args_path)

    model = setup_model(model_hyperparams, decoder_hyperparams)
    load_model(model_path, model)
    model.to(device)
    model.eval()

    dataset = CombinedDataset([(args.pc_data_path, args.mesh_data_path)],
                              [read_uniform_point_cloud_data, read_mesh_data])
    minibatch_generator = MinibatchGenerator(
        with_normal=training_hyperparams.with_normal, device=device)

    loss_function = LossFunction(
        tau=training_hyperparams.loss_tau,
        lam=training_hyperparams.loss_lambda,
        alpha=training_hyperparams.loss_alpha,
        num_eikonal_samples=0,
        eikonal_term_default_stddev=training_hyperparams.eikonal_term_stddev)

    optimize_latent = LatentOptimization(model=model,
                                         loss_function=loss_function,
                                         params=latent_optimization_params)
    # Sample counts and iterations come from the hyperparameter object, the
    # same source ``chamfer_distance_surface_data`` uses.
    compute_chamfer_distance = ChamferDistance(
        minibatch_generator=minibatch_generator,
        optimize_latent_func=optimize_latent,
        model=model,
        grid_size=args.grid_size,
        latent_optimization_num_samples=latent_optimization_params.num_samples,
        chamfer_distance_num_samples=args.chamfer_distance_num_samples)

    for data_tuple in dataset:
        pc_data: UniformPointCloudData = data_tuple[0]
        mesh_data: MeshData = data_tuple[1]
        chamfer_distance = compute_chamfer_distance(pc_data, mesh_data)

        # Assumes a layout of .../<category_id>/<model_id>/<file>.
        parts = Path(pc_data.path).parts
        category_id = parts[-3]
        model_id = parts[-2]
        print(category_id, model_id, chamfer_distance, flush=True)
        result = {
            "model_id": f"{category_id}_{model_id}",
            "chamfer_distance": chamfer_distance,
            "chamfer_distance_num_samples": args.chamfer_distance_num_samples,
            "latent_optimization_num_samples":
            latent_optimization_params.num_samples,
            "grid_size": args.grid_size,
            "latent_optimization_iterations":
            latent_optimization_params.iterations,
        }
        with open(metrics_directory / f"{category_id}_{model_id}.json",
                  "w") as f:
            json.dump(result, f, indent=4, sort_keys=True)


@client.command(name="chamfer_distance_surface_data")
@click.argument("--checkpoint-directory", type=str, required=True)
@click.argument("--checkpoint-epoch", type=int, default=None)
@click.argument("--output-directory", type=str, required=True)
@click.argument("--pc-data-path", type=str, required=True)
@click.argument("--grid-size", type=int, default=256)
@click.argument("--chamfer-distance-num-samples", type=int, default=30000)
@click.argument("--skip-if-exists", is_flag=True)
@click.hyperparameter_class(LatentOptimizationHyperparameters)
def chamfer_distance_surface_data(
        args, latent_optimization_params: LatentOptimizationHyperparameters):
    # Evaluate a trained decoder on a single point cloud. The reconstruction
    # is compared against samples from the same point cloud, and the result
    # is written to ``<output-directory>/metrics/<category>_<model>.json``.
    # With ``--skip-if-exists``, an already-written result file short-circuits
    # before any model loading.
    device = torch.device("cuda", 0)
    pc_data_path = Path(args.pc_data_path)
    output_directory = Path(args.output_directory)
    metrics_directory = output_directory / "metrics"
    mkdir(output_directory)
    mkdir(metrics_directory)

    # Assumes a layout of .../<category_id>/<model_id>/<file>.
    parts = pc_data_path.parts
    category_id = parts[-3]
    model_id = parts[-2]
    result_path = metrics_directory / f"{category_id}_{model_id}.json"
    if args.skip_if_exists and result_path.exists():
        print("skip", pc_data_path)
        return

    checkpoint_directory = Path(args.checkpoint_directory)
    args_path = checkpoint_directory / "args.json"
    if args.checkpoint_epoch is None:
        model_path = checkpoint_directory / "decoder.pt"
    else:
        model_path = checkpoint_directory / f"decoder.{args.checkpoint_epoch}.pt"
    # Explicit checks instead of ``assert``, which ``python -O`` strips.
    for required_path in (args_path, model_path):
        if not required_path.is_file():
            raise FileNotFoundError(
                f"required file not found: {required_path}")

    shutil.copyfile(args_path, output_directory / "args.json")

    model_hyperparams = ModelHyperparameters.load_json(args_path)
    decoder_hyperparams = DecoderHyperparameters.load_json(args_path)
    training_hyperparams = TrainingHyperparameters.load_json(args_path)

    # num_data is forced to 0 before building the model (presumably so no
    # per-training-shape latent table is allocated; TODO confirm in
    # setup_model).
    model_hyperparams.num_data = 0
    model = setup_model(model_hyperparams, decoder_hyperparams)
    load_model(model_path, model)
    model.to(device)
    model.eval()

    dataset = UniformPointCloudDataset([(pc_data_path, None)])
    print(pc_data_path, flush=True)

    minibatch_generator = MinibatchGenerator(
        with_normal=training_hyperparams.with_normal, device=device)

    loss_function = LossFunction(
        tau=training_hyperparams.loss_tau,
        lam=training_hyperparams.loss_lambda,
        alpha=training_hyperparams.loss_alpha,
        num_eikonal_samples=0,
        eikonal_term_default_stddev=training_hyperparams.eikonal_term_stddev)
    optimize_latent = AdaptiveLatentOptimization(
        model=model,
        loss_function=loss_function,
        params=latent_optimization_params)
    compute_chamfer_distance = SurfaceChamferDistance(
        minibatch_generator=minibatch_generator,
        optimize_latent_func=optimize_latent,
        model=model,
        grid_size=args.grid_size,
        latent_optimization_num_samples=latent_optimization_params.num_samples,
        chamfer_distance_num_samples=args.chamfer_distance_num_samples)

    for pc_data in dataset:
        # None when surface extraction failed; serialized as JSON null.
        chamfer_distance = compute_chamfer_distance(pc_data)

        # Same path parsing as above (Path.parts rather than splitting on "/").
        parts = Path(pc_data.path).parts
        category_id = parts[-3]
        model_id = parts[-2]
        result_path = metrics_directory / f"{category_id}_{model_id}.json"

        print(category_id, model_id, chamfer_distance, flush=True)
        print(result_path)

        result = {
            "model_id":
            f"{category_id}_{model_id}",
            "chamfer_distance":
            chamfer_distance,
            "chamfer_distance_num_samples":
            args.chamfer_distance_num_samples,
            "grid_size":
            args.grid_size,
            "latent_optimization_num_samples":
            latent_optimization_params.num_samples,
            "latent_optimization_iterations":
            latent_optimization_params.iterations,
            "latent_optimization_initial_lr":
            latent_optimization_params.initial_lr,
            "latent_optimization_decrease_lr_every":
            latent_optimization_params.decrease_lr_every,
            "latent_optimization_lr_decay_factor":
            latent_optimization_params.lr_decay_factor,
        }
        with open(result_path, "w") as f:
            json.dump(result, f, indent=4, sort_keys=True)


if __name__ == "__main__":
    # Script entry point: dispatch to the CLI command group.
    client()
