import json
import math
import shutil
from collections import defaultdict
from pathlib import Path

import colorful
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import trimesh
from scipy.spatial import cKDTree as KDTree
from tabulate import tabulate

from pcn import click, load_model, mkdir
from pcn import point_cloud as pcl
from pcn.datasets.uniform_sparse_sampling import (Dataset, MeshDataDescription,
                                                  MinibatchGenerator,
                                                  PointCloudAndMeshPairDataset,
                                                  GtUniformPointCloudData)
from pcn.experiment import (Model, ModelHyperparameters,
                            TrainingHyperparameters, setup_model)


def _sample_surface_points(faces: np.ndarray, vertices: np.ndarray,
                           num_point_samples: int):
    """Draw point samples from the surface of a triangle mesh.

    Args:
        faces: Face array passed to ``trimesh.Trimesh`` as ``faces``.
        vertices: Vertex array passed to ``trimesh.Trimesh`` as ``vertices``.
        num_point_samples: Number of samples requested from trimesh.

    Returns:
        The sampled points as returned by
        ``trimesh.sample.sample_surface``.
    """
    surface = trimesh.Trimesh(vertices=vertices, faces=faces)
    # sample_surface yields the points together with the index of the face
    # each point was drawn from; only the points are used here.
    points, _ = trimesh.sample.sample_surface(surface, num_point_samples)
    return points


def _compute_chamfer_distance(gt_surface_points: np.ndarray,
                              pred_surface_points: np.ndarray):
    """Return the symmetric chamfer distance using squared distances.

    The result is the sum of two terms: the mean squared nearest-neighbour
    distance from every predicted point to the ground-truth set, and the
    mean squared nearest-neighbour distance from every ground-truth point
    to the predicted set.

    Args:
        gt_surface_points: Ground-truth points, one point per row.
        pred_surface_points: Predicted points, one point per row.

    Returns:
        The sum of the two directional mean squared distances.
    """
    # Prediction -> ground truth.
    pred_to_gt, _ = KDTree(gt_surface_points).query(pred_surface_points)
    # Ground truth -> prediction.
    gt_to_pred, _ = KDTree(pred_surface_points).query(gt_surface_points)

    forward_term = np.mean(np.square(pred_to_gt))
    backward_term = np.mean(np.square(gt_to_pred))
    return forward_term + backward_term


def _compute_non_squared_chamfer_distance(gt_surface_points: np.ndarray,
                                          pred_surface_points: np.ndarray):
    """Return the symmetric chamfer distance using plain distances.

    Same construction as the squared variant, except that the
    nearest-neighbour distances are averaged as they are instead of being
    squared first.

    Args:
        gt_surface_points: Ground-truth points, one point per row.
        pred_surface_points: Predicted points, one point per row.

    Returns:
        The sum of the two directional mean nearest-neighbour distances.
    """
    # Prediction -> ground truth.
    pred_to_gt, _ = KDTree(gt_surface_points).query(pred_surface_points)
    # Ground truth -> prediction.
    gt_to_pred, _ = KDTree(pred_surface_points).query(gt_surface_points)

    forward_term = np.mean(pred_to_gt)
    backward_term = np.mean(gt_to_pred)
    return forward_term + backward_term


def _summarize(result_directory: str):
    """Collect the per-model metric files of one evaluation run.

    Reads ``<result_directory>/args.json`` and every
    ``<result_directory>/metrics/*.json`` file and merges the metric files
    column-wise. Metric files are visited in sorted path order so that the
    order of the collected values is reproducible across file systems.

    Args:
        result_directory: Directory that holds ``args.json`` and a
            ``metrics`` sub-directory.

    Returns:
        A ``(table, args, num_files)`` tuple. ``table`` maps every key found
        in the metric files to the list of its values, ``args`` is the
        decoded content of ``args.json`` and ``num_files`` is the number of
        metric files found. Files that could not be decoded are reported on
        stdout and skipped, but they are still included in ``num_files``.

    Raises:
        FileNotFoundError: If ``args.json`` is missing.
    """
    result_directory = Path(result_directory)
    metrics_directory = result_directory / "metrics"
    args_path = result_directory / "args.json"
    # An explicit check instead of ``assert`` so it survives ``python -O``.
    if not args_path.is_file():
        raise FileNotFoundError(args_path)
    with open(args_path, encoding="utf-8") as f:
        args = json.load(f)

    table = defaultdict(list)
    metrics_path_list = sorted(metrics_directory.glob("*.json"))
    for metrics_path in metrics_path_list:
        try:
            with open(metrics_path, encoding="utf-8") as f:
                metrics = json.load(f)
        except json.decoder.JSONDecodeError:
            # Best effort: a truncated or corrupt file must not abort the
            # whole summary.
            print("Error:", metrics_path)
            continue
        for key, value in metrics.items():
            table[key].append(value)

    return table, args, len(metrics_path_list)


# Root command group of this script. The commands defined below attach
# themselves to it through ``@client.command(...)`` and the group is invoked
# from the ``__main__`` guard at the bottom of the file.
# NOTE(review): ``click`` is imported from ``pcn`` here, so it may be a
# project wrapper rather than the PyPI package -- confirm before relying on
# stock click behaviour.
@click.group()
def client():
    pass


def _rank_by_metric(result, metric_name: str, k: int, descending: bool):
    """Return up to ``k`` ``(value, model_id)`` pairs ordered by value.

    Entries whose value is ``None``, not greater than zero, or NaN are left
    out. ``descending=False`` yields the smallest values first,
    ``descending=True`` the largest.
    """
    pairs = [(value, model_id)
             for value, model_id in zip(result[metric_name],
                                        result["model_id"])
             if value is not None and value > 0]
    pairs.sort(key=lambda item: item[0])
    if descending:
        pairs.reverse()
    return pairs[:k]


# CLI command: print a summary of the metric files stored under
# ``--result-directory``: the stored hyperparameters, the models with the
# ten smallest and ten largest values per metric, a histogram, and a one-row
# table with mean and standard deviation.
# Documented with comments rather than a docstring so that the generated
# command-line help text is left untouched.
@client.command()
@click.argument("--result-directory", type=str, required=True)
def summarize(args):
    metric_list = ["chamfer_distance"]
    settings_list = ["num_input_points"]
    # Keep the CLI namespace (``args``) and the hyperparameters stored next
    # to the results under separate names.
    result, hyperparams, num_data = _summarize(args.result_directory)
    print(colorful.bold("Hyperparameters:"))
    print(hyperparams)

    # find top-k (smallest metric values)
    top_k = 10
    for metric_name in metric_list:
        print(colorful.bold(f"Top {top_k} {metric_name}:"))
        for value, model_id in _rank_by_metric(result, metric_name, top_k,
                                               descending=False):
            print(value, model_id)

    # find worst-k (largest metric values)
    worst_k = 10
    for metric_name in metric_list:
        print(colorful.bold(f"Worst {worst_k} {metric_name}:"))
        for value, model_id in _rank_by_metric(result, metric_name, worst_k,
                                               descending=True):
            print(value, model_id)

    data = {
        "learning_rate": hyperparams["learning_rate"],
        "num_data": num_data,
    }
    for metric_name in metric_list:
        recorded = np.array(
            [value for value in result[metric_name] if value is not None],
            dtype=float)
        # NaN never compares equal to anything, itself included, so the NaN
        # entries have to be located with np.isnan rather than ``==``.
        ignore = np.isnan(recorded)
        print(recorded.dtype)
        print(np.count_nonzero(ignore), "nans")
        values = recorded[~ignore]
        # Non-positive values are treated as invalid, as in the rankings.
        values = values[values > 0]
        print(values.shape, values.dtype)
        if values.size == 0:
            # np.min / np.max raise on empty input; report NaN statistics
            # instead of crashing.
            print(f"No valid {metric_name} values found.")
            data[f"{metric_name}_mean"] = float("nan")
            data[f"{metric_name}_std"] = float("nan")
            continue
        print(np.min(values), np.max(values))
        print(colorful.bold("Histogram:"))
        print(np.histogram(values))
        data[f"{metric_name}_mean"] = np.mean(values)
        data[f"{metric_name}_std"] = np.std(values)
    for key in settings_list:
        # The setting is reported from the first metric file that carries
        # it; ``None`` when no file recorded it.
        recorded_settings = result[key]
        data[key] = recorded_settings[0] if recorded_settings else None

    tabulate_row = [data["num_data"]]
    for key in settings_list:
        value = data[key]
        tabulate_row.append(f"{value}")
    for metric_name in metric_list:
        mean = data[f"{metric_name}_mean"]
        std = data[f"{metric_name}_std"]
        tabulate_row.append(f"{mean:.06f} (±{std:.06f})")

    print(colorful.bold("Result:"))
    print(
        tabulate([tabulate_row],
                 headers=["# of data"] + settings_list + metric_list,
                 tablefmt="github"))


# CLI command: evaluate a checkpoint on one (point cloud, mesh) pair.
# Ground-truth points are sampled from the mesh surface, the model predicts
# a dense point cloud from the sparse input, and the chamfer distance
# between the two is written to ``<output-directory>/metrics`` as one JSON
# file per model.
# Documented with comments rather than a docstring so that the generated
# command-line help text is left untouched.
@client.command(name="chamfer_distance_mesh_data")
@click.argument("--checkpoint-directory", type=str, required=True)
@click.argument("--checkpoint-epoch", type=int, default=None)
@click.argument("--output-directory", type=str, required=True)
@click.argument("--npz-path", type=str, required=True)
@click.argument("--obj-path", type=str, required=True)
@click.argument("--num-input-points", type=int, default=50)
@click.argument("--chamfer-distance-method",
                type=click.Choice(["symmetric", "non_squared_symmetric"]),
                required=True)
def chamfer_distance_mesh_data(args):
    device = torch.device("cuda", 0)
    npz_path = Path(args.npz_path)
    obj_path = Path(args.obj_path)
    output_directory = Path(args.output_directory)
    metrics_directory = output_directory / "metrics"
    mkdir(output_directory)
    mkdir(metrics_directory)

    checkpoint_directory = Path(args.checkpoint_directory)
    args_path = checkpoint_directory / "args.json"
    if args.checkpoint_epoch is None:
        model_path = checkpoint_directory / "model.pt"
    else:
        model_path = checkpoint_directory / f"model.{args.checkpoint_epoch}.pt"
    # Explicit checks instead of ``assert`` so they survive ``python -O``
    # and name the missing file.
    for required_path in (args_path, model_path):
        if not required_path.is_file():
            raise FileNotFoundError(required_path)

    model_hyperparams = ModelHyperparameters.load_json(args_path)

    model = setup_model(model_hyperparams)
    load_model(model_path, model)
    model.to(device)
    model.eval()

    # Keep a copy of the hyperparameters next to the results; the
    # ``summarize`` command reads it from there.
    dest_path = output_directory / "args.json"
    if not dest_path.exists():
        shutil.copyfile(args_path, dest_path)

    npz_obj_path_list = [(npz_path, obj_path)]
    dataset = PointCloudAndMeshPairDataset(npz_obj_path_list)
    minibatch_generator = MinibatchGenerator(
        num_input_points=args.num_input_points, device=device)

    if args.chamfer_distance_method == "symmetric":
        compute_chamfer_distance = _compute_chamfer_distance
    elif args.chamfer_distance_method == "non_squared_symmetric":
        compute_chamfer_distance = _compute_non_squared_chamfer_distance
    else:
        raise NotImplementedError()

    for pc_data, mesh_data in (data_tuple[:2] for data_tuple in dataset):
        gt_surface_points = _sample_surface_points(
            mesh_data.vertex_indices, mesh_data.vertices,
            model_hyperparams.num_dense_gt_points)

        batch = minibatch_generator([pc_data])

        with torch.no_grad():
            _, pred_dense_points = model(batch.input_points)
        pred_dense_points = pred_dense_points.cpu().numpy()[0]

        print(batch.input_points[0].shape, "->", pred_dense_points.shape)

        chamfer_distance = compute_chamfer_distance(gt_surface_points,
                                                    pred_dense_points)

        # The two directories above the file name identify the model.
        # NOTE(review): assumes a .../<category_id>/<model_id>/<file>
        # layout -- confirm against the dataset.
        parts = Path(pc_data.path).parts
        category_id = parts[-3]
        model_id = parts[-2]
        print(category_id, model_id, chamfer_distance, flush=True)

        # Only values that exist in this command are recorded. The keys
        # match the record written by ``chamfer_distance_surface_data`` and
        # include ``num_input_points``, which ``summarize`` reads.
        result = {
            "model_id": f"{category_id}_{model_id}",
            "chamfer_distance": chamfer_distance,
            "num_input_points": args.num_input_points,
            "chamfer_distance_num_samples":
            model_hyperparams.num_dense_gt_points,
            "chamfer_distance_method": args.chamfer_distance_method,
        }
        with open(metrics_directory / f"{category_id}_{model_id}.json",
                  "w") as f:
            json.dump(result, f, indent=4, sort_keys=True)


# CLI command: evaluate a checkpoint on one point-cloud file. The model
# predicts a dense point cloud from the sparse input, both the prediction
# and the ground-truth dense points of the batch are transformed with the
# sample's ``scale`` and ``offset``, and their chamfer distance is written
# to ``<output-directory>/metrics`` as one JSON file per model.
# Documented with comments rather than a docstring so that the generated
# command-line help text is left untouched.
@client.command(name="chamfer_distance_surface_data")
@click.argument("--checkpoint-directory", type=str, required=True)
@click.argument("--checkpoint-epoch", type=int, default=None)
@click.argument("--output-directory", type=str, required=True)
@click.argument("--npz-path", type=str, required=True)
@click.argument("--num-input-points", type=int, default=50)
@click.argument("--chamfer-distance-method",
                type=click.Choice(["symmetric", "non_squared_symmetric"]),
                required=True)
def chamfer_distance_surface_data(args):
    device = torch.device("cuda", 0)
    source_npz = Path(args.npz_path)
    out_root = Path(args.output_directory)
    metrics_root = out_root / "metrics"
    for directory in (out_root, metrics_root):
        mkdir(directory)

    # Locate the hyperparameter file and the requested weights.
    ckpt_root = Path(args.checkpoint_directory)
    hyperparams_file = ckpt_root / "args.json"
    weights_name = ("model.pt" if args.checkpoint_epoch is None else
                    f"model.{args.checkpoint_epoch}.pt")
    weights_file = ckpt_root / weights_name
    assert hyperparams_file.is_file()
    assert weights_file.is_file()

    # Keep a copy of the hyperparameters next to the results.
    copied_args = out_root / "args.json"
    if not copied_args.exists():
        shutil.copyfile(hyperparams_file, copied_args)

    model_hyperparams = ModelHyperparameters.load_json(hyperparams_file)

    model = setup_model(model_hyperparams)
    load_model(weights_file, model)
    model.to(device)
    model.eval()

    dataset = Dataset([source_npz])
    make_batch = MinibatchGenerator(num_input_points=args.num_input_points,
                                    device=device)

    # Map the CLI choice to the distance implementation.
    distance_functions = {
        "symmetric": _compute_chamfer_distance,
        "non_squared_symmetric": _compute_non_squared_chamfer_distance,
    }
    compute_chamfer_distance = distance_functions.get(
        args.chamfer_distance_method)
    if compute_chamfer_distance is None:
        raise NotImplementedError()

    for pc_data in dataset:
        batch = make_batch([pc_data])

        with torch.no_grad():
            _, dense_output = model(batch.input_points)
        pred_dense_points = dense_output.cpu().numpy()[0]
        gt_surface_points = batch.gt_dense_points.cpu().numpy()[0]

        # Apply the sample's scale and offset to both point sets.
        # NOTE(review): presumably this undoes the dataset normalisation --
        # confirm against the dataset code.
        gt_surface_points = (gt_surface_points / pc_data.scale -
                             pc_data.offset)
        pred_dense_points = (pred_dense_points / pc_data.scale -
                             pc_data.offset)

        print(batch.input_points[0].shape, "->", pred_dense_points.shape)
        print(gt_surface_points.shape, "<->", pred_dense_points.shape)

        chamfer_distance = compute_chamfer_distance(gt_surface_points,
                                                    pred_dense_points)

        # The two path components before the file name identify the model.
        path_components = str(pc_data.path).split("/")
        category_id = path_components[-3]
        model_id = path_components[-2]
        print(category_id, model_id, chamfer_distance, flush=True)

        record_name = f"{category_id}_{model_id}"
        result_path = metrics_root / f"{record_name}.json"
        print(result_path)

        record = dict(
            model_id=record_name,
            chamfer_distance=chamfer_distance,
            num_input_points=args.num_input_points,
            chamfer_distance_num_samples=model_hyperparams.num_dense_gt_points,
            chamfer_distance_method=args.chamfer_distance_method,
        )
        with open(result_path, "w") as f:
            json.dump(record, f, indent=4, sort_keys=True)


# Run the command group when this file is executed as a script; importing
# the module does not invoke it.
if __name__ == "__main__":
    client()
