import json
import shutil
from collections import defaultdict
from pathlib import Path
from typing import Callable
import colorful
import numpy as np
import torch
import trimesh
from scipy.spatial import cKDTree as KDTree
from tabulate import tabulate

from occupancy_networks import click, load_model, mkdir
from occupancy_networks.datasets.surface.uniform_sparse_sampling import (
    MeshData, SurfacePointCloudData, SurfaceMeshPairDataset, SurfaceDataset)
from occupancy_networks.experiments.encoder import (MarchingCubes, Model,
                                                    ModelHyperparameters,
                                                    TrainingHyperparameters,
                                                    setup_model)


def _sample_surface_points(faces: np.ndarray, vertices: np.ndarray,
                           num_point_samples: int):
    """Sample random points on the surface of a triangle mesh.

    Args:
        faces: Triangle vertex indices of the mesh.
        vertices: Vertex positions of the mesh.
        num_point_samples: How many surface points to draw.

    Returns:
        The sampled points as returned by ``trimesh.sample.sample_surface``
        (per the trimesh docs, faces are chosen weighted by area). The face
        index of each sample is discarded.
    """
    surface_mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
    points, _ = trimesh.sample.sample_surface(surface_mesh, num_point_samples)
    return points


def _compute_asymmetric_chamfer_distance(gt_surface_points: np.ndarray,
                                         pred_surface_points: np.ndarray):
    # gt_points_kd_tree = KDTree(gt_surface_points)
    # distances, locations = gt_points_kd_tree.query(pred_surface_points)
    # cd_term1 = np.mean(np.square(distances))

    pred_points_kd_tree = KDTree(pred_surface_points)
    distances, locations = pred_points_kd_tree.query(gt_surface_points)
    cd_term2 = np.mean(np.square(distances))

    return cd_term2


def _compute_chamfer_distance(gt_surface_points: np.ndarray,
                              pred_surface_points: np.ndarray):
    gt_points_kd_tree = KDTree(gt_surface_points)
    distances, locations = gt_points_kd_tree.query(pred_surface_points)
    cd_term1 = np.mean(np.square(distances))

    pred_points_kd_tree = KDTree(pred_surface_points)
    distances, locations = pred_points_kd_tree.query(gt_surface_points)
    cd_term2 = np.mean(np.square(distances))

    return (cd_term1 + cd_term2) / 2


def _compute_non_squared_chamfer_distance(gt_surface_points: np.ndarray,
                                          pred_surface_points: np.ndarray):
    gt_points_kd_tree = KDTree(gt_surface_points)
    distances, locations = gt_points_kd_tree.query(pred_surface_points)
    cd_term1 = np.mean(distances)

    pred_points_kd_tree = KDTree(pred_surface_points)
    distances, locations = pred_points_kd_tree.query(gt_surface_points)
    cd_term2 = np.mean(distances)

    return (cd_term1 + cd_term2) / 2


class ChamferDistance:
    """Chamfer distance between a reconstructed mesh and a ground-truth mesh.

    Each call does the following:
    1. Encodes a random subset (drawn with replacement) of the point cloud's
       surface points.
    2. Extracts a mesh with ``MarchingCubes`` over the grid [-1, 1] and maps
       its vertices back into the ground-truth mesh's frame.
    3. Rescales both meshes by one tenth of the ground-truth bounding box's
       longest edge.
    4. Samples points on both surfaces and applies ``compute_cd_func``.
    """

    def __init__(self, model: Model, grid_size: int, num_input_points: int,
                 chamfer_distance_num_samples: int,
                 compute_cd_func: Callable[[np.ndarray, np.ndarray], float]):
        """
        Args:
            model: Network providing ``encode_inputs`` and ``get_device``;
                also handed to ``MarchingCubes``.
            grid_size: Marching-cubes grid resolution (passed through).
            num_input_points: Number of surface points fed to the encoder.
            chamfer_distance_num_samples: Points sampled on each mesh surface.
            compute_cd_func: ``f(gt_points, pred_points) -> float``.
        """
        self.model = model
        self.grid_size = grid_size
        self.num_input_points = num_input_points
        self.chamfer_distance_num_samples = chamfer_distance_num_samples
        self.compute_cd = compute_cd_func

    def __call__(self, pc_data: SurfacePointCloudData, mesh_data: MeshData):
        """Return the distance for one shape.

        Returns ``None`` in two cases: marching cubes raises ``ValueError``,
        or the ground-truth mesh has zero (or non-finite) extent.
        """
        gt_faces = mesh_data.vertex_indices
        gt_vertices = mesh_data.vertices

        bbox = np.max(gt_vertices, axis=0) - np.min(gt_vertices, axis=0)
        max_edge_length = np.max(bbox)
        if not max_edge_length > 0:
            # Degenerate mesh: the rescaling below would divide by zero.
            return None
        # Distances are reported in units of 1/10 of the longest bbox edge,
        # which makes them comparable across shapes of different sizes.
        unit_1 = max_edge_length / 10

        rand_indices = np.random.choice(len(pc_data.surface_points),
                                        size=self.num_input_points)
        input_points = pc_data.surface_points[rand_indices]
        input_points = torch.from_numpy(input_points[None, :, :]).to(
            self.model.get_device()).type(torch.float32)
        with torch.no_grad():
            c = self.model.encode_inputs(input_points)

        marching_cubes = MarchingCubes(model=self.model,
                                       grid_size=self.grid_size,
                                       grid_max_value=1,
                                       grid_min_value=-1)
        try:
            mc_vertices, mc_faces = marching_cubes(c)
        except ValueError:
            # Treated as a failed reconstruction (presumably no iso-surface
            # was found — confirm against MarchingCubes).
            return None

        # Map the reconstruction from the normalized point-cloud frame back
        # to the mesh frame. This presumably inverts (v + offset) * scale —
        # TODO confirm.
        mc_vertices = mc_vertices / pc_data.scale - pc_data.offset

        # Out-of-place division so that mesh_data.vertices is left untouched.
        # The original in-place `/=` rescaled the caller's mesh on every call.
        gt_vertices = gt_vertices / unit_1
        mc_vertices = mc_vertices / unit_1

        gt_surface_points = _sample_surface_points(
            gt_faces, gt_vertices, self.chamfer_distance_num_samples)
        pred_surface_points = _sample_surface_points(
            mc_faces, mc_vertices, self.chamfer_distance_num_samples)

        return self.compute_cd(gt_surface_points, pred_surface_points)


class SurfaceChamferDistance:
    """Chamfer distance between a reconstruction and its own input point cloud.

    Both the reference points and the encoder inputs are drawn (with
    replacement) from ``pc_data.surface_points``. The comparison happens in
    the point cloud's normalized frame: neither side is un-normalized with
    ``pc_data.scale`` / ``pc_data.offset``.
    """

    def __init__(
            self, model: Model, grid_size: int, num_input_points: int,
            chamfer_distance_num_samples: int,
            chamfer_distance_func: Callable[[np.ndarray, np.ndarray], float]):
        """
        Args:
            model: Network providing ``encode_inputs`` and ``get_device``;
                also handed to ``MarchingCubes``.
            grid_size: Marching-cubes grid resolution (passed through).
            num_input_points: Number of surface points fed to the encoder.
            chamfer_distance_num_samples: Reference points drawn from the
                point cloud and points sampled on the reconstructed mesh.
            chamfer_distance_func: ``f(gt_points, pred_points) -> float``.
        """
        self.model = model
        self.grid_size = grid_size
        self.num_input_points = num_input_points
        self.chamfer_distance_num_samples = chamfer_distance_num_samples
        self.compute_chamfer_distance = chamfer_distance_func

    @staticmethod
    def _draw(pool, count):
        # Uniformly pick `count` rows of `pool`, with replacement.
        picked = np.random.choice(len(pool), size=count)
        return pool[picked]

    def __call__(self, pc_data: SurfacePointCloudData):
        """Return the distance for one point cloud, or None if marching cubes fails."""
        pool = pc_data.surface_points
        # Order matters for RNG reproducibility: reference first, then inputs.
        reference_points = self._draw(pool, self.chamfer_distance_num_samples)
        observed_points = self._draw(pool, self.num_input_points)

        device = self.model.get_device()
        batch = torch.from_numpy(observed_points).to(device).type(
            torch.float32).unsqueeze(0)

        with torch.no_grad():
            latent = self.model.encode_inputs(batch)

        extractor = MarchingCubes(model=self.model,
                                  grid_size=self.grid_size,
                                  grid_max_value=1,
                                  grid_min_value=-1)
        try:
            vertices, faces = extractor(latent)
        except ValueError:
            return None

        predicted_points = _sample_surface_points(
            faces, vertices, self.chamfer_distance_num_samples)
        return self.compute_chamfer_distance(reference_points,
                                             predicted_points)


def _summarize(result_directory: str):
    result_directory = Path(result_directory)
    metrics_directory = result_directory / "metrics"
    args_path = result_directory / "args.json"
    assert args_path.is_file()
    with open(args_path) as f:
        args = json.load(f)
    table = defaultdict(list)
    metrics_path_list = list(metrics_directory.glob("*.json"))
    for metrics_path in metrics_path_list:
        try:
            with open(metrics_path) as f:
                metrics = json.load(f)
        except json.decoder.JSONDecodeError:
            print("Error:", metrics_path)
            continue
        for key, value in metrics.items():
            table[key].append(value)

    return table, args, len(metrics_path_list)


@click.group()
def client():
    # Root command group; the evaluation subcommands attach to it through
    # `@client.command(...)`.
    return None


def _is_valid_metric(value) -> bool:
    """Return True for metric values that should enter the statistics.

    ``None`` (written when a reconstruction failed) and non-positive values
    are excluded.
    """
    return value is not None and value > 0


@client.command()
@click.argument("--result-directory", type=str, required=True)
def summarize(args):
    # Print the best/worst models, a histogram, and the mean (± std) of the
    # Chamfer distances collected under --result-directory.
    metric_name = "chamfer_distance"
    # The middle element is the saved args.json. It is unused here and must
    # not shadow the CLI `args`.
    result, _, num_data = _summarize(args.result_directory)

    # Valid (value, model_id) pairs in ascending order of distance. `sorted`
    # is stable, so ties keep their file order.
    ranking = sorted(
        ((value, model_id)
         for value, model_id in zip(result[metric_name], result["model_id"])
         if _is_valid_metric(value)),
        key=lambda item: item[0])

    # Best-k: the smallest distances.
    top_k = 10
    print(colorful.bold(f"Top {top_k} {metric_name}:"))
    for value, model_id in ranking[:top_k]:
        print(value, model_id)

    # Worst-k: the largest distances, largest first.
    worst_k = 10
    print(colorful.bold(f"Worst {worst_k} {metric_name}:"))
    for value, model_id in ranking[::-1][:worst_k]:
        print(value, model_id)

    values = np.array(
        [value for value in result[metric_name] if _is_valid_metric(value)])
    print(num_data, "->", len(values))
    if len(values) == 0:
        # np.min / np.max / np.histogram would raise on an empty array.
        print(colorful.bold(f"No valid {metric_name} values found."))
        return
    print(values.shape, values.dtype)
    print(np.min(values), np.max(values))
    print(colorful.bold("Histogram:"))
    print(np.histogram(values))
    mean = np.mean(values)
    std = np.std(values)

    tabulate_row = [num_data, f"{mean:.06f} (±{std:.06f})"]
    print(colorful.bold("Result:"))
    print(
        tabulate([tabulate_row],
                 headers=["# of data", metric_name],
                 tablefmt="github"))


@client.command(name="chamfer_distance_data")
@click.argument("--checkpoint-directory", type=str, required=True)
@click.argument("--checkpoint-epoch", type=int, default=None)
@click.argument("--output-directory", type=str, required=True)
@click.argument("--npz-path", type=str, required=True)
@click.argument("--obj-path", type=str, required=True)
@click.argument("--grid-size", type=int, default=256)
@click.argument("--num-input-points", type=int, default=300)
@click.argument("--chamfer-distance-num-samples", type=int, default=100000)
@click.argument("--chamfer-distance-method",
                type=click.Choice(["symmetric", "non_squared_symmetric"]),
                required=True)
def chamfer_distance_data(args):
    # Evaluate one (point-cloud .npz, ground-truth .obj) pair:
    # - reconstruct a mesh from the point cloud;
    # - compare it with the .obj mesh;
    # - write <output-directory>/metrics/<category>_<model>.json.
    # Use the first GPU when present; otherwise fall back to the CPU instead
    # of failing in `model.to(device)`.
    device = (torch.device("cuda", 0)
              if torch.cuda.is_available() else torch.device("cpu"))
    npz_path = Path(args.npz_path)
    obj_path = Path(args.obj_path)
    output_directory = Path(args.output_directory)
    metrics_directory = output_directory / "metrics"
    mkdir(output_directory)
    mkdir(metrics_directory)

    checkpoint_directory = Path(args.checkpoint_directory)
    args_path = checkpoint_directory / "args.json"
    if args.checkpoint_epoch is None:
        model_path = checkpoint_directory / "model.pt"
    else:
        model_path = checkpoint_directory / f"model.{args.checkpoint_epoch}.pt"
    # Explicit checks (not `assert`) so they are not stripped under -O.
    for required_path in (args_path, model_path):
        if not required_path.is_file():
            raise FileNotFoundError(f"Not a file: {required_path}")

    # Keep the training hyperparameters next to the metrics; `_summarize`
    # requires <output-directory>/args.json.
    dest_path = output_directory / "args.json"
    if not dest_path.exists():
        shutil.copyfile(args_path, dest_path)

    model_hyperparams = ModelHyperparameters.load_json(args_path)

    model = setup_model(model_hyperparams)
    load_model(model_path, model)
    model.to(device)
    model.eval()

    npz_obj_path_list = [(npz_path, obj_path)]
    dataset = SurfaceMeshPairDataset(npz_obj_path_list)

    compute_cd_funcs = {
        "symmetric": _compute_chamfer_distance,
        "non_squared_symmetric": _compute_non_squared_chamfer_distance,
    }
    if args.chamfer_distance_method not in compute_cd_funcs:
        raise NotImplementedError(args.chamfer_distance_method)
    compute_chamfer_distance = ChamferDistance(
        model=model,
        grid_size=args.grid_size,
        num_input_points=args.num_input_points,
        chamfer_distance_num_samples=args.chamfer_distance_num_samples,
        compute_cd_func=compute_cd_funcs[args.chamfer_distance_method])

    for data_tuple in dataset:
        pc_data = data_tuple[0]
        mesh_data = data_tuple[1]
        chamfer_distance = compute_chamfer_distance(pc_data, mesh_data)

        # Path.parts is separator-agnostic and collapses duplicate separators,
        # unlike str(path).split("/"). The layout is assumed to be
        # .../<category_id>/<model_id>/<file>.
        parts = Path(pc_data.path).parts
        category_id = parts[-3]
        model_id = parts[-2]
        print(category_id, model_id, chamfer_distance, flush=True)
        result = {
            "model_id": f"{category_id}_{model_id}",
            "chamfer_distance": chamfer_distance,
            "chamfer_distance_num_samples": args.chamfer_distance_num_samples,
            "num_input_points": args.num_input_points,
            "grid_size": args.grid_size,
            "chamfer_distance_method": args.chamfer_distance_method,
        }
        with open(metrics_directory / f"{category_id}_{model_id}.json",
                  "w") as f:
            json.dump(result, f, indent=4, sort_keys=True)


@client.command(name="chamfer_distance_surface_data")
@click.argument("--checkpoint-directory", type=str, required=True)
@click.argument("--checkpoint-epoch", type=int, default=None)
@click.argument("--output-directory", type=str, required=True)
@click.argument("--npz-path", type=str, required=True)
@click.argument("--grid-size", type=int, default=256)
@click.argument("--num-input-points", type=int, default=50)
@click.argument("--chamfer-distance-num-samples", type=int, default=30000)
@click.argument("--chamfer-distance-method",
                type=click.Choice(
                    ["asymmetric", "symmetric", "non_squared_symmetric"]),
                required=True)
def chamfer_distance_surface_data(args):
    # Evaluate reconstructions against the input point cloud itself (no
    # ground-truth mesh). One JSON file per shape is written to
    # <output-directory>/metrics/.
    # Use the first GPU when present; otherwise fall back to the CPU instead
    # of failing in `model.to(device)`.
    device = (torch.device("cuda", 0)
              if torch.cuda.is_available() else torch.device("cpu"))
    npz_path = Path(args.npz_path)
    output_directory = Path(args.output_directory)
    metrics_directory = output_directory / "metrics"
    mkdir(output_directory)
    mkdir(metrics_directory)

    checkpoint_directory = Path(args.checkpoint_directory)
    args_path = checkpoint_directory / "args.json"
    if args.checkpoint_epoch is None:
        model_path = checkpoint_directory / "model.pt"
    else:
        model_path = checkpoint_directory / f"model.{args.checkpoint_epoch}.pt"
    # Explicit checks (not `assert`) so they are not stripped under -O.
    for required_path in (args_path, model_path):
        if not required_path.is_file():
            raise FileNotFoundError(f"Not a file: {required_path}")

    # Keep the training hyperparameters next to the metrics; `_summarize`
    # requires <output-directory>/args.json.
    dest_path = output_directory / "args.json"
    if not dest_path.exists():
        shutil.copyfile(args_path, dest_path)

    model_hyperparams = ModelHyperparameters.load_json(args_path)

    model = setup_model(model_hyperparams)
    load_model(model_path, model)
    model.to(device)
    model.eval()

    dataset = SurfaceDataset([npz_path], memory_caching=False)

    chamfer_distance_funcs = {
        "symmetric": _compute_chamfer_distance,
        "non_squared_symmetric": _compute_non_squared_chamfer_distance,
        "asymmetric": _compute_asymmetric_chamfer_distance,
    }
    if args.chamfer_distance_method not in chamfer_distance_funcs:
        raise NotImplementedError(args.chamfer_distance_method)

    compute_chamfer_distance = SurfaceChamferDistance(
        model=model,
        grid_size=args.grid_size,
        num_input_points=args.num_input_points,
        chamfer_distance_num_samples=args.chamfer_distance_num_samples,
        chamfer_distance_func=chamfer_distance_funcs[
            args.chamfer_distance_method])

    for pc_data in dataset:
        # The layout is assumed to be .../<category_id>/<*>/<model_id>.
        parts = Path(pc_data.path).parts
        model_id = parts[-1]
        category_id = parts[-3]
        result_path = metrics_directory / f"{category_id}_{model_id}.json"

        chamfer_distance = compute_chamfer_distance(pc_data)
        print(category_id, model_id, chamfer_distance, flush=True)
        print(result_path)

        result = {
            "model_id": f"{category_id}_{model_id}",
            "chamfer_distance": chamfer_distance,
            "chamfer_distance_num_samples": args.chamfer_distance_num_samples,
            "num_input_points": args.num_input_points,
            "grid_size": args.grid_size,
            # Record the method, as chamfer_distance_data does, so results
            # from different methods can be told apart.
            "chamfer_distance_method": args.chamfer_distance_method,
        }
        with open(result_path, "w") as f:
            json.dump(result, f, indent=4, sort_keys=True)


if __name__ == "__main__":
    # Script entry point: run the `client` command group defined above.
    client()
