import json
import math
import shutil
from collections import defaultdict
from pathlib import Path

import colorful
import numpy as np
import pandas as pd
import torch
import trimesh
from scipy.spatial import cKDTree as KDTree
from tabulate import tabulate

from deep_sdf import click, load_model, mkdir
from deep_sdf.datasets.classes import Minibatch
from deep_sdf.datasets.functions import read_mesh_data, read_sdf_data
from deep_sdf.datasets.shapenet import (CombinedDataset, MeshData,
                                        MinibatchGenerator, SdfData,
                                        UniformPointCloudData,
                                        UniformPointCloudDataset)
from deep_sdf.experiments.learning_shape_space import (
    AdaptiveLatentOptimization, ChamferDistance, DecoderHyperparameters,
    LatentOptimization, LatentOptimizationHyperparameters, LossFunction,
    MarchingCubes, Model, ModelHyperparameters, TrainingHyperparameters,
    setup_model)


def _sample_surface_points(faces: np.ndarray, vertices: np.ndarray,
                           num_point_samples: int):
    """Sample ``num_point_samples`` points lying on a triangle mesh surface.

    The mesh is built from ``vertices``/``faces`` with ``trimesh`` and sampled
    with ``trimesh.sample.sample_surface``; only the point coordinates are
    returned (the per-sample face indices are discarded).
    """
    surface_mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
    points, _unused_face_ids = trimesh.sample.sample_surface(
        surface_mesh, num_point_samples)
    return points


def _compute_chamfer_distance(gt_surface_points: np.ndarray,
                              pred_surface_points: np.ndarray):
    """Return the symmetric *squared* Chamfer distance between two point sets.

    For each direction (pred -> gt, then gt -> pred) the squared distance from
    every query point to its nearest neighbour in the other set is averaged;
    the result is the mean of those two averages.
    """
    def mean_squared_nn(reference, queries):
        # Nearest-neighbour distances from each query to ``reference``.
        nn_dist, _ = KDTree(reference).query(queries)
        return np.mean(np.square(nn_dist))

    pred_to_gt = mean_squared_nn(gt_surface_points, pred_surface_points)
    gt_to_pred = mean_squared_nn(pred_surface_points, gt_surface_points)
    return (pred_to_gt + gt_to_pred) / 2


def _compute_non_squared_chamfer_distance(gt_surface_points: np.ndarray,
                                          pred_surface_points: np.ndarray):
    """Return the symmetric Chamfer distance using plain (unsquared) distances.

    Same as the squared variant except that nearest-neighbour distances are
    averaged directly: mean over pred -> gt, mean over gt -> pred, then the
    mean of the two.
    """
    def mean_nn(reference, queries):
        # Average distance from each query to its nearest ``reference`` point.
        nn_dist, _ = KDTree(reference).query(queries)
        return np.mean(nn_dist)

    pred_to_gt = mean_nn(gt_surface_points, pred_surface_points)
    gt_to_pred = mean_nn(pred_surface_points, gt_surface_points)
    return (pred_to_gt + gt_to_pred) / 2


def _normalize(vec: np.ndarray):
    """Return ``vec`` divided by its Euclidean norm.

    There is no guard for a zero-length vector; numpy will then produce
    nan/inf entries.
    """
    length = np.linalg.norm(vec)
    return vec / length


def _look_at(eye: np.ndarray, center: np.ndarray, up: np.ndarray):
    """Build an orthonormal camera frame for a viewer at ``eye`` facing ``center``.

    Axes: ``z = normalize(eye - center)``, ``x = normalize(up x z)`` and
    ``y = normalize(z x x)`` (the cross products are taken before
    normalisation, exactly as derived).

    Returns:
        rotation_matrix: float32 3x3 array whose *columns* are x, y and z.
        translation_vector: ``[-x.eye, -y.eye, -z.eye]``.

    NOTE(review): the rotation stores the axes as columns while the
    translation projects ``eye`` onto them row-wise (i.e. ``-R.T @ eye``);
    confirm callers expect this combination.
    """
    def unit(vec):
        return vec / np.linalg.norm(vec)

    eye, center, up = (np.asanyarray(arr) for arr in (eye, center, up))

    z_axis = unit(eye - center)
    raw_x = np.cross(up, z_axis)
    raw_y = np.cross(z_axis, raw_x)
    x_axis, y_axis = unit(raw_x), unit(raw_y)

    rotation_matrix = np.column_stack((x_axis, y_axis, z_axis)).astype(np.float32)
    translation_vector = np.array(
        [-axis @ eye for axis in (x_axis, y_axis, z_axis)])
    return rotation_matrix, translation_vector


def _summarize(result_directory: str):
    """Collect the per-object metrics written by one evaluation run.

    Args:
        result_directory: run output directory containing ``args.json`` and a
            ``metrics/`` sub-directory of per-object ``*.json`` files.

    Returns:
        ``(table, args, num_data)``:
        ``table`` is a ``defaultdict(list)`` mapping each metric key to the
        list of its values, one entry per metrics file that contains the key,
        in file-name order. Lists only stay index-aligned across keys if every
        file has the same keys. ``args`` is the parsed ``args.json``.
        ``num_data`` is the number of metrics files found (0 if the
        ``metrics/`` directory is missing).

    Raises:
        FileNotFoundError: if ``args.json`` does not exist.
    """
    result_directory = Path(result_directory)
    args_path = result_directory / "args.json"
    # Explicit raise instead of ``assert``: asserts vanish under ``python -O``.
    if not args_path.is_file():
        raise FileNotFoundError(f"missing hyperparameter file: {args_path}")
    with open(args_path, encoding="utf-8") as f:
        args = json.load(f)

    table = defaultdict(list)
    # Path.glob yields entries in arbitrary filesystem order; sorting makes the
    # row order (and therefore tie-breaking in rankings) reproducible.
    metrics_path_list = sorted((result_directory / "metrics").glob("*.json"))
    for metrics_path in metrics_path_list:
        with open(metrics_path, encoding="utf-8") as f:
            metrics = json.load(f)
        for key, value in metrics.items():
            table[key].append(value)

    return table, args, len(metrics_path_list)


class SurfaceChamferDistance:
    """Chamfer-distance evaluator for a decoder fitted to a surface point cloud.

    Each call draws two independent random subsets of ``pc_data.vertices``
    (``np.random.choice`` with its default ``replace=True``). One subset is the
    fitting target for a latent code, which is optimised from a small random
    start. The other is the ground-truth surface. A mesh is extracted with
    ``MarchingCubes`` on the [-1, 1] grid, points are sampled from it, and the
    squared Chamfer distance to the ground-truth subset is returned.
    """

    def __init__(self, optimize_latent_func: LatentOptimization, model: Model,
                 grid_size: int, latent_optimization_num_samples: int,
                 chamfer_distance_num_samples: int):
        """Store collaborators and sample counts used by ``__call__``.

        Args:
            optimize_latent_func: called as ``(minibatch, initial_z)`` and
                expected to return the optimised latent code.
            model: decoder; ``__call__`` uses its ``get_device()`` and
                ``z_map`` attributes.
            grid_size: resolution passed to ``MarchingCubes``.
            latent_optimization_num_samples: number of points fed to the
                latent optimisation.
            chamfer_distance_num_samples: number of points in each set used
                for the Chamfer distance.
        """
        self.optimize_latent = optimize_latent_func
        self.model = model
        self.grid_size = grid_size
        self.latent_optimization_num_samples = latent_optimization_num_samples
        self.chamfer_distance_num_samples = chamfer_distance_num_samples

    def _draw_vertices(self, pc_data: UniformPointCloudData, count: int):
        """Return ``count`` rows of ``pc_data.vertices`` picked with replacement."""
        picked = np.random.choice(len(pc_data.vertices), size=count)
        return pc_data.vertices[picked]

    def __call__(self, pc_data: UniformPointCloudData):
        """Return the Chamfer distance for one shape.

        Returns ``None`` when mesh extraction (or rescaling its vertices)
        raises ``ValueError``.
        """
        device = self.model.get_device()

        # Ground-truth points, mapped out of the normalised frame with the same
        # inverse transform applied to the reconstructed mesh below.
        gt_points = self._draw_vertices(pc_data,
                                        self.chamfer_distance_num_samples)
        gt_points = gt_points / pc_data.scale - pc_data.offset

        # Fitting target: a second random subset, left in the normalised frame
        # and given a leading batch dimension of one.
        # NOTE(review): torch.from_numpy keeps the numpy dtype; if vertices are
        # float64 this differs from the float32 ``distances`` tensor — confirm.
        target = torch.from_numpy(
            self._draw_vertices(pc_data, self.latent_optimization_num_samples))
        target = target.to(device)[None, :, :]
        minibatch = Minibatch(points=target,
                              distances=torch.zeros(target.shape[:2]).to(device),
                              data_indices=None)

        latent_dim = self.model.z_map.shape[1]
        z_start = torch.normal(mean=0,
                               std=0.01,
                               size=(1, latent_dim),
                               dtype=torch.float32).to(device)
        z = self.optimize_latent(minibatch, z_start)

        extract_mesh = MarchingCubes(model=self.model,
                                     grid_size=self.grid_size,
                                     grid_max_value=1,
                                     grid_min_value=-1)
        try:
            mesh_vertices, mesh_faces = extract_mesh(z)
            mesh_vertices = mesh_vertices / pc_data.scale - pc_data.offset
        except ValueError:
            return None

        pred_points = _sample_surface_points(mesh_faces, mesh_vertices,
                                             self.chamfer_distance_num_samples)
        return _compute_chamfer_distance(gt_points, pred_points)


@click.group()
def client():
    # Root command group for this script; each subcommand below registers
    # itself with ``@client.command``. Intentionally empty. A comment is used
    # instead of a docstring so the group's help text is not altered.
    pass


@client.command()
@click.argument("--result-directory", type=str, required=True)
def summarize(args):
    # Print the hyperparameters, the best/worst objects and aggregate
    # statistics (histogram, mean, std) of one evaluation run directory.
    # Comments are used instead of a docstring so the command's help text is
    # unchanged.
    metric_list = ["chamfer_distance"]
    settings_list = [
        "latent_optimization_num_samples", "latent_optimization_iterations"
    ]
    top_k = 10
    worst_k = 10

    def is_valid(value):
        # Failed reconstructions are stored as None (JSON null) or as a
        # negative sentinel. NaN would corrupt both sorting and statistics.
        return value is not None and math.isfinite(value) and value >= 0

    result, hyperparams, num_data = _summarize(args.result_directory)
    print(colorful.bold("Hyperparameters:"))
    print(hyperparams)

    model_ids = result["model_id"]
    # Valid (value, model_id) pairs per metric, smallest (best) first.
    rankings = {
        metric_name: sorted(
            ((value, model_id)
             for value, model_id in zip(result[metric_name], model_ids)
             if is_valid(value)),
            key=lambda item: item[0])
        for metric_name in metric_list
    }

    # find top-k
    for metric_name in metric_list:
        print(colorful.bold(f"Top {top_k} {metric_name}:"))
        for value, model_id in rankings[metric_name][:top_k]:
            print(value, model_id)

    # find worst-k
    for metric_name in metric_list:
        print(colorful.bold(f"Worst {worst_k} {metric_name}:"))
        for value, model_id in rankings[metric_name][::-1][:worst_k]:
            print(value, model_id)

    data = {
        "learning_rate": hyperparams.get("learning_rate"),
        "num_data": num_data,
    }
    for metric_name in metric_list:
        values = np.array([value for value, _ in rankings[metric_name]],
                          dtype=float)
        print(colorful.bold("Histogram:"))
        print(np.histogram(values))
        # Avoid numpy's empty-slice warning when every object failed.
        data[f"{metric_name}_mean"] = (values.mean()
                                       if values.size else float("nan"))
        data[f"{metric_name}_std"] = (values.std()
                                      if values.size else float("nan"))
    for key in settings_list:
        # The evaluation commands write the same setting for every object of
        # a run, so the first entry is representative. None if no data.
        data[key] = result[key][0] if result[key] else None

    tabulate_row = [data["num_data"]]
    for key in settings_list:
        tabulate_row.append(f"{data[key]}")
    for metric_name in metric_list:
        mean = data[f"{metric_name}_mean"]
        std = data[f"{metric_name}_std"]
        tabulate_row.append(f"{mean:.06f} (±{std:.06f})")

    print(colorful.bold("Result:"))
    print(
        tabulate([tabulate_row],
                 headers=["# of data"] + settings_list + metric_list,
                 tablefmt="github"))


@client.command(name="summarize_by_category")
@click.argument("--result-directory", type=str, required=True)
def summarize_by_category(args):
    # Print the best/worst objects and the per-category mean of each metric
    # for one evaluation run directory. Comments are used instead of a
    # docstring so the command's help text is unchanged.
    metric_list = ["chamfer_distance"]
    top_k = 10
    worst_k = 10
    map_category_name = {
        "02691156": "Plane",
        "04256520": "Sofa",
        "04379243": "Table",
        "04530566": "Vessel",
        "02958343": "Car",
        "03001627": "Chair",
        "02933112": "Cabinet",
        "03636649": "Lamp",
        "02818832": "Bed",
        "02828884": "Bench",
        "02871439": "Bookshelf",
        "02924116": "Bus",
        "03467517": "Guitar",
        "03790512": "Motorbike",
        "03948459": "Pistol",
        "04225987": "Skateboard",
    }

    def is_valid(value):
        # Failed reconstructions are stored as None (JSON null) or as a
        # negative sentinel. NaN would corrupt both sorting and averages.
        return value is not None and math.isfinite(value) and value >= 0

    def first_or_none(key):
        # Settings are written identically for every object of a run.
        values = result[key]
        return values[0] if values else None

    result, hyperparams, _ = _summarize(args.result_directory)
    print(colorful.bold("Hyperparameters:"))
    print(hyperparams)

    latent_optimization_num_samples = first_or_none(
        "latent_optimization_num_samples")
    latent_optimization_iterations = first_or_none(
        "latent_optimization_iterations")

    model_ids = result["model_id"]
    # Valid (value, model_id) pairs per metric, smallest (best) first.
    rankings = {
        metric_name: sorted(
            ((value, model_id)
             for value, model_id in zip(result[metric_name], model_ids)
             if is_valid(value)),
            key=lambda item: item[0])
        for metric_name in metric_list
    }

    # find top-k
    for metric_name in metric_list:
        print(colorful.bold(f"Top {top_k} {metric_name}:"))
        for value, model_id in rankings[metric_name][:top_k]:
            print(value, model_id)

    # find worst-k
    for metric_name in metric_list:
        print(colorful.bold(f"Worst {worst_k} {metric_name}:"))
        for value, model_id in rankings[metric_name][::-1][:worst_k]:
            print(value, model_id)

    data_list = []
    category_id_set = set()
    for k, model_id in enumerate(model_ids):
        # Ids are written as "<category_id>_<object_id>"; partition keeps any
        # further underscores inside the object id.
        category_id, _, object_id = model_id.partition("_")
        category_id_set.add(category_id)
        row = {"category_id": category_id, "object_id": object_id}
        keep = True
        for metric_name in metric_list:
            value = result[metric_name][k]
            row[metric_name] = value
            keep = keep and is_valid(value)
        if keep:
            data_list.append(row)

    print(colorful.bold("Result:"))
    # Explicit columns so filtering works even when every object was skipped.
    table = pd.DataFrame(data_list,
                         columns=["category_id", "object_id", *metric_list])
    for metric_name in metric_list:
        print(colorful.bold_green(metric_name))
        tabulate_header = [
            "latent_optimization_num_samples", "latent_optimization_iterations"
        ]
        tabulate_row = [
            latent_optimization_num_samples, latent_optimization_iterations
        ]
        for category_id in sorted(category_id_set):
            df = table[table["category_id"] == category_id]
            mean = df[metric_name].mean() if not df.empty else float("nan")
            # Fall back to the raw id for categories missing from the map.
            tabulate_header.append(
                map_category_name.get(category_id, category_id))
            tabulate_row.append(f"{mean:.06f}")
        print(
            tabulate([tabulate_row],
                     headers=tabulate_header,
                     tablefmt="github"))


@client.command(name="summarize_all")
@click.argument("--result-root-directory", type=str, required=True)
def summarize_all(args):
    # Aggregate every run directory under --result-root-directory and print
    # one table per learning rate (mean ± std of each metric per run).
    # Comments are used instead of a docstring so the command's help text is
    # unchanged.
    # NOTE(review): the keys "max_num_target" and "loss_kld_initial_weight"
    # must exist in each run's args.json — confirm they match the current
    # training argument schema.
    metric_list = ["chamfer_distance"]
    result_root_directory = Path(args.result_root_directory)

    def is_valid(value):
        # Failed reconstructions are stored as None (JSON null) or as a
        # negative sentinel. NaN would corrupt the statistics.
        return value is not None and math.isfinite(value) and value >= 0

    # Only sub-directories holding an args.json are runs. Stray files or
    # folders would otherwise make _summarize fail. Sorted for stable output.
    run_directories = sorted(
        path for path in result_root_directory.iterdir()
        if path.is_dir() and (path / "args.json").is_file())

    data_list = []
    for result_directory in run_directories:
        # _summarize takes only the directory (it previously also received
        # metric_list, which raised TypeError).
        result, run_args, num_data = _summarize(result_directory)
        data = {
            "max_num_target": run_args["max_num_target"],
            "kld_weight": run_args["loss_kld_initial_weight"],
            "learning_rate": run_args["learning_rate"],
            "num_samples": run_args["max_num_target"],
            "num_data": num_data,
        }
        for metric_name in metric_list:
            values = np.array(
                [value for value in result[metric_name] if is_valid(value)],
                dtype=float)
            data[f"{metric_name}_mean"] = (values.mean()
                                           if values.size else float("nan"))
            data[f"{metric_name}_std"] = (values.std()
                                          if values.size else float("nan"))
        data_list.append(data)

    if not data_list:
        print(f"No result directories found in {result_root_directory}")
        return
    table = pd.DataFrame(data_list)

    # Group by every learning rate actually present instead of a hard-coded
    # list, so runs with other rates are not silently dropped.
    lr_list = sorted(table["learning_rate"].unique())
    for lr in lr_list:
        df = table[table["learning_rate"] == lr]
        tabulate_rows = []
        print(lr)
        for _, row in df.iterrows():
            tabulate_row = [
                row["num_data"], row["num_samples"], row["kld_weight"]
            ]
            for metric_name in metric_list:
                mean = row[f"{metric_name}_mean"]
                std = row[f"{metric_name}_std"]
                tabulate_row.append(f"{mean:.06f} (±{std:.06f})")
            tabulate_rows.append(tabulate_row)

        print(
            tabulate(tabulate_rows,
                     headers=["# of data", "# of samples", "kld_weight"] +
                     metric_list,
                     tablefmt="github"))


@client.command(name="chamfer_distance_mesh_data")
@click.argument("--checkpoint-directory", type=str, required=True)
@click.argument("--checkpoint-epoch", type=int, default=None)
@click.argument("--output-directory", type=str, required=True)
@click.argument("--sdf-data-path", type=str, required=True)
@click.argument("--mesh-data-path", type=str, required=True)
@click.argument("--grid-size", type=int, default=512)
@click.argument("--chamfer-distance-num-samples", type=int, default=30000)
@click.hyperparameter_class(LatentOptimizationHyperparameters)
def chamfer_distance_mesh_data(
        args, latent_optimization_params: LatentOptimizationHyperparameters):
    # Evaluate a trained model on paired (SDF samples, mesh) data. For every
    # pair yielded by the dataset, compute a Chamfer distance with
    # ChamferDistance and write it to <output-directory>/metrics/<id>.json.
    # The checkpoint's args.json is copied next to the metrics so the
    # summarize commands can read the run's hyperparameters.
    # Comments are used instead of a docstring so the help text is unchanged.
    device = torch.device("cuda", 0)
    sdf_data_path = Path(args.sdf_data_path)
    mesh_data_path = Path(args.mesh_data_path)
    output_directory = Path(args.output_directory)
    metrics_directory = output_directory / "metrics"
    mkdir(output_directory)
    mkdir(metrics_directory)

    checkpoint_directory = Path(args.checkpoint_directory)
    args_path = checkpoint_directory / "args.json"
    if args.checkpoint_epoch is None:
        model_path = checkpoint_directory / "model.pt"
    else:
        model_path = checkpoint_directory / f"model.{args.checkpoint_epoch}.pt"
    # Explicit raises instead of asserts, which vanish under ``python -O``.
    for required_path in (args_path, model_path):
        if not required_path.is_file():
            raise FileNotFoundError(f"missing checkpoint file: {required_path}")

    shutil.copyfile(args_path, output_directory / "args.json")

    model_hyperparams = ModelHyperparameters.load_json(args_path)
    decoder_hyperparams = DecoderHyperparameters.load_json(args_path)
    training_hyperparams = TrainingHyperparameters.load_json(args_path)

    model = setup_model(model_hyperparams, decoder_hyperparams)
    load_model(model_path, model)
    model.to(device)
    model.eval()

    dataset = CombinedDataset([(sdf_data_path, mesh_data_path)],
                              [read_sdf_data, read_mesh_data])

    minibatch_generator = MinibatchGenerator(device=device)

    loss_function = LossFunction(
        lam=training_hyperparams.loss_lam,
        clamping_distance=training_hyperparams.clamping_distance)

    # Take the latent-optimisation settings from the hyperparameter object
    # (as chamfer_distance_surface_data does) rather than from ``args``.
    num_samples = latent_optimization_params.num_samples
    iterations = latent_optimization_params.iterations

    optimize_latent = LatentOptimization(model=model,
                                         loss_function=loss_function,
                                         params=latent_optimization_params)
    compute_chamfer_distance = ChamferDistance(
        loss_function=loss_function,
        minibatch_generator=minibatch_generator,
        optimize_latent_func=optimize_latent,
        model=model,
        grid_size=args.grid_size,
        latent_optimization_num_samples=num_samples,
        chamfer_distance_num_samples=args.chamfer_distance_num_samples)

    for data_tuple in dataset:
        sdf_data: SdfData = data_tuple[0]
        mesh_data: MeshData = data_tuple[1]
        chamfer_distance = compute_chamfer_distance(sdf_data, mesh_data)

        # Assumes a .../<category_id>/<model_id>/<file> layout.
        parts = Path(str(sdf_data.path)).parts
        category_id = parts[-3]
        model_id = parts[-2]
        print(category_id, model_id, chamfer_distance, flush=True)
        result = {
            "model_id": f"{category_id}_{model_id}",
            "chamfer_distance": chamfer_distance,
            "chamfer_distance_num_samples": args.chamfer_distance_num_samples,
            "latent_optimization_num_samples": num_samples,
            "grid_size": args.grid_size,
            "latent_optimization_iterations": iterations,
        }
        with open(metrics_directory / f"{category_id}_{model_id}.json",
                  "w") as f:
            json.dump(result, f, indent=4, sort_keys=True)


@client.command(name="chamfer_distance_surface_data")
@click.argument("--checkpoint-directory", type=str, required=True)
@click.argument("--checkpoint-epoch", type=int, default=None)
@click.argument("--output-directory", type=str, required=True)
@click.argument("--pc-data-path", type=str, required=True)
@click.argument("--grid-size", type=int, default=256)
@click.argument("--chamfer-distance-num-samples", type=int, default=30000)
@click.hyperparameter_class(LatentOptimizationHyperparameters)
def chamfer_distance_surface_data(
        args, latent_optimization_params: LatentOptimizationHyperparameters):
    # Evaluate a trained decoder on uniform surface point clouds. For each
    # point cloud, fit a latent code with AdaptiveLatentOptimization, score it
    # with SurfaceChamferDistance, and write the result plus the optimisation
    # settings to <output-directory>/metrics/<id>.json.
    # Comments are used instead of a docstring so the help text is unchanged.
    device = torch.device("cuda", 0)
    pc_data_path = Path(args.pc_data_path)
    output_directory = Path(args.output_directory)
    metrics_directory = output_directory / "metrics"
    mkdir(output_directory)
    mkdir(metrics_directory)

    checkpoint_directory = Path(args.checkpoint_directory)
    args_path = checkpoint_directory / "args.json"
    if args.checkpoint_epoch is None:
        model_path = checkpoint_directory / "decoder.pt"
    else:
        model_path = checkpoint_directory / f"decoder.{args.checkpoint_epoch}.pt"
    # Explicit raises instead of asserts, which vanish under ``python -O``.
    for required_path in (args_path, model_path):
        if not required_path.is_file():
            raise FileNotFoundError(f"missing checkpoint file: {required_path}")

    shutil.copyfile(args_path, output_directory / "args.json")

    model_hyperparams = ModelHyperparameters.load_json(args_path)
    decoder_hyperparams = DecoderHyperparameters.load_json(args_path)
    training_hyperparams = TrainingHyperparameters.load_json(args_path)

    # NOTE(review): presumably zeroes the per-training-shape latent table
    # (z_map) since codes are optimised from scratch here — confirm.
    model_hyperparams.num_data = 0
    model = setup_model(model_hyperparams, decoder_hyperparams)
    load_model(model_path, model)

    model.to(device)
    model.eval()

    dataset = UniformPointCloudDataset([pc_data_path])
    print(pc_data_path, flush=True)

    loss_function = LossFunction(
        lam=training_hyperparams.loss_lam,
        clamping_distance=training_hyperparams.clamping_distance)

    optimize_latent = AdaptiveLatentOptimization(
        model=model,
        loss_function=loss_function,
        params=latent_optimization_params)
    compute_chamfer_distance = SurfaceChamferDistance(
        optimize_latent_func=optimize_latent,
        model=model,
        grid_size=args.grid_size,
        latent_optimization_num_samples=latent_optimization_params.num_samples,
        chamfer_distance_num_samples=args.chamfer_distance_num_samples)

    for pc_data in dataset:
        # None when mesh extraction failed; stored as JSON null.
        chamfer_distance = compute_chamfer_distance(pc_data)

        # Assumes a .../<category_id>/<model_id>/<file> layout. Path.parts is
        # separator-agnostic, unlike splitting the string on "/".
        parts = Path(str(pc_data.path)).parts
        category_id = parts[-3]
        model_id = parts[-2]
        result_path = metrics_directory / f"{category_id}_{model_id}.json"

        print(category_id, model_id, chamfer_distance, flush=True)
        print(result_path)

        result = {
            "model_id":
            f"{category_id}_{model_id}",
            "chamfer_distance":
            chamfer_distance,
            "chamfer_distance_num_samples":
            args.chamfer_distance_num_samples,
            "grid_size":
            args.grid_size,
            "latent_optimization_num_samples":
            latent_optimization_params.num_samples,
            "latent_optimization_iterations":
            latent_optimization_params.iterations,
            "latent_optimization_initial_lr":
            latent_optimization_params.initial_lr,
            "latent_optimization_decrease_lr_every":
            latent_optimization_params.decrease_lr_every,
            "latent_optimization_lr_decay_factor":
            latent_optimization_params.lr_decay_factor,
        }
        with open(result_path, "w") as f:
            json.dump(result, f, indent=4, sort_keys=True)


if __name__ == "__main__":
    # Script entry point: hand control to the ``client`` command group, which
    # dispatches to the selected subcommand.
    client()
