import json
import math
import shutil
from collections import defaultdict
from pathlib import Path

import colorful
import matplotlib.pyplot as plt
import numpy as np
import open3d as o3d
import pandas as pd
import seaborn as sns
import torch
import trimesh
from scipy.spatial import cKDTree as KDTree
from tabulate import tabulate

from pcn import click, load_model, mkdir
from pcn import point_cloud as pcl
from pcn.datasets.reproduction import (Dataset, MinibatchGenerator)
from pcn.experiment import (Model, ModelHyperparameters,
                            TrainingHyperparameters, setup_model)


def _read_pcd(path: str):
    """Read the XYZ coordinates stored in an ASCII ``.pcd`` file.

    Every line up to and including the one whose first token is ``DATA`` is
    treated as header and skipped. Each following non-blank line contributes
    one point built from its first three whitespace-separated fields; any
    additional fields on the line are ignored.

    Args:
        path: Location of the ``.pcd`` file.

    Returns:
        A float64 array of shape ``(num_points, 3)``. The shape is ``(0, 3)``
        when the file contains no data rows.

    Raises:
        ValueError: If the ``DATA`` line declares an encoding other than
            ``ascii`` or a coordinate cannot be parsed as a float.
        IndexError: If a data row has fewer than three fields.
    """
    points = []
    in_data_section = False
    # Stream the file line by line instead of materialising it first.
    with open(path, "r") as f:
        for raw_line in f:
            # split() without an argument tolerates tabs and repeated spaces
            # and yields an empty list for blank lines.
            fields = raw_line.split()
            if not fields:
                continue
            if fields[0] == "DATA":
                if len(fields) > 1 and fields[1].lower() != "ascii":
                    raise ValueError(
                        f"Unsupported PCD data encoding '{fields[1]}' in "
                        f"{path}; only ascii is supported")
                in_data_section = True
                continue
            if in_data_section:
                points.append(
                    [float(fields[0]),
                     float(fields[1]),
                     float(fields[2])])
    # reshape keeps the result two-dimensional even when no point was read.
    return np.array(points, dtype=np.float64).reshape(-1, 3)


def _normalize(vec: np.ndarray):
    """Return ``vec`` divided by its norm as computed by ``np.linalg.norm``.

    There is no guard for a zero norm; the division is performed as-is.
    """
    magnitude = np.linalg.norm(vec)
    return np.divide(vec, magnitude)


def _look_at(eye: np.ndarray, center: np.ndarray, up: np.ndarray):
    """Build a rotation matrix and translation vector from a look-at setup.

    The third axis is the normalized direction from ``center`` to ``eye``.
    The first axis is ``up`` crossed with that direction, and the second is
    the third crossed with the (not yet normalized) first; both are then
    normalized.

    Args:
        eye: 3-vector, converted with ``np.asanyarray``.
        center: 3-vector, converted with ``np.asanyarray``.
        up: 3-vector, converted with ``np.asanyarray``.

    Returns:
        A tuple ``(rotation_matrix, translation_vector)``. The rotation
        matrix is a 3x3 float32 array whose columns are the three axes. The
        translation vector holds the dot product of each negated axis with
        ``eye``.
    """
    eye, center, up = (np.asanyarray(v) for v in (eye, center, up))

    third_axis = _normalize(eye - center)
    first_axis = np.cross(up, third_axis)
    # The second axis is derived from the un-normalized first axis; both are
    # brought to unit length afterwards.
    second_axis = np.cross(third_axis, first_axis)
    first_axis = _normalize(first_axis)
    second_axis = _normalize(second_axis)

    axes = (first_axis, second_axis, third_axis)
    # Row i collects component i of every axis, i.e. the axes are the columns.
    rotation_matrix = np.array(
        [[axis[i] for axis in axes] for i in range(3)],
        dtype=np.float32,
    )
    translation_vector = np.array([(-axis) @ eye for axis in axes])

    return rotation_matrix, translation_vector


def _sample_surface_points(faces: np.ndarray, vertices: np.ndarray,
                           num_point_samples: int):
    """Sample points from the surface of the mesh given by faces/vertices.

    A ``trimesh.Trimesh`` is built from the inputs and passed to
    ``trimesh.sample.sample_surface``. Only the sampled points are returned;
    the face indices reported alongside them are discarded.
    """
    surface = trimesh.Trimesh(faces=faces, vertices=vertices)
    sampled_points, _ = trimesh.sample.sample_surface(surface,
                                                      num_point_samples)
    return sampled_points


def _compute_chamfer_distance(gt_surface_points: np.ndarray,
                              pred_surface_points: np.ndarray):
    """Symmetric Chamfer distance based on squared nearest-neighbour distances.

    For each direction (prediction to ground truth, ground truth to
    prediction) every query point is matched to its nearest neighbour in the
    other set via a KD-tree, and the squared distances are averaged. The
    result is the mean of the two directional averages.
    """

    def _mean_squared_nn_distance(queries, references):
        nn_distances, _ = KDTree(references).query(queries)
        return np.mean(np.square(nn_distances))

    pred_to_gt = _mean_squared_nn_distance(pred_surface_points,
                                           gt_surface_points)
    gt_to_pred = _mean_squared_nn_distance(gt_surface_points,
                                           pred_surface_points)
    return (pred_to_gt + gt_to_pred) / 2


def _compute_non_squared_chamfer_distance(gt_surface_points: np.ndarray,
                                          pred_surface_points: np.ndarray):
    """Symmetric Chamfer distance based on plain nearest-neighbour distances.

    Same procedure as the squared variant, except that the nearest-neighbour
    distances are averaged without squaring them first.
    """
    directional_means = []
    # First pass: prediction queried against the ground-truth tree.
    # Second pass: ground truth queried against the prediction tree.
    for references, queries in ((gt_surface_points, pred_surface_points),
                                (pred_surface_points, gt_surface_points)):
        nn_distances, _ = KDTree(references).query(queries)
        directional_means.append(np.mean(nn_distances))

    return (directional_means[0] + directional_means[1]) / 2


def _summarize(result_directory: str):
    """Collect the per-model metrics stored in a result directory.

    Reads ``<result_directory>/args.json`` and every ``*.json`` file in
    ``<result_directory>/metrics``. Metrics files that are not valid JSON are
    reported on stdout and skipped.

    Args:
        result_directory: Directory containing ``args.json`` and ``metrics``.

    Returns:
        A tuple ``(table, args, num_files)``. ``table`` maps each metric key
        to the list of values found for it, in sorted file-path order (the
        lists only line up across keys if every file carries the same keys).
        ``args`` is the parsed content of ``args.json``. ``num_files`` is the
        number of metrics files found, including any that failed to parse.

    Raises:
        FileNotFoundError: If ``args.json`` does not exist.
    """
    result_directory = Path(result_directory)
    args_path = result_directory / "args.json"
    # Explicit check rather than ``assert`` so it still runs under ``-O``.
    if not args_path.is_file():
        raise FileNotFoundError(f"args.json not found: {args_path}")
    with open(args_path) as f:
        args = json.load(f)

    # glob() yields files in arbitrary filesystem order; sorting makes the
    # row order (and therefore tie-breaking in later rankings) reproducible.
    metrics_path_list = sorted((result_directory / "metrics").glob("*.json"))

    table = defaultdict(list)
    for metrics_path in metrics_path_list:
        try:
            with open(metrics_path) as f:
                metrics = json.load(f)
        except json.JSONDecodeError:
            print("Error:", metrics_path)
            continue
        for key, value in metrics.items():
            table[key].append(value)

    return table, args, len(metrics_path_list)


@click.group()
def client():
    """Root command group; the commands in this file attach to it."""


@client.command()
@click.argument("--result-directory", type=str, required=True)
def summarize(args):
    """Print rankings and overall statistics for the Chamfer distance.

    Loads the metrics of ``args.result_directory`` and prints, in order: the
    10 smallest and the 10 largest valid distances with their model ids, a
    few diagnostics (dtype, NaN count, shape, min/max, histogram) and a
    one-row table with the number of metrics files and ``mean (±std)``.

    A value is considered valid when it is not ``None``, not NaN and
    strictly positive.
    """
    result, _, num_data = _summarize(args.result_directory)
    metric_name = "chamfer_distance"
    top_k = 10
    worst_k = 10

    model_ids = result["model_id"]
    raw_values = result[metric_name]

    # Sort once and read both rankings off the same list. NaN compares False
    # against 0, so it is rejected here together with None and values <= 0.
    ascending = sorted(
        ((value, model_id) for value, model_id in zip(raw_values, model_ids)
         if value is not None and value > 0),
        key=lambda item: item[0])

    # find best-k (smallest distance first)
    print(colorful.bold(f"Top {top_k} {metric_name}:"))
    for value, model_id in ascending[:top_k]:
        print(value, model_id)

    # find worst-k (largest distance first)
    print(colorful.bold(f"Worst {worst_k} {metric_name}:"))
    for value, model_id in ascending[::-1][:worst_k]:
        print(value, model_id)

    values = np.array([value for value in raw_values if value is not None],
                      dtype=np.float64)
    print(values.dtype)
    # NaN never compares equal to anything (itself included), so an ``==``
    # test cannot detect it; np.isnan is required.
    is_nan = np.isnan(values)
    print(np.count_nonzero(is_nan), "nans")
    values = values[~is_nan]
    values = values[values > 0]
    print(values.shape, values.dtype)

    # min/max/mean are undefined on an empty array; stop with a clear message.
    if values.size == 0:
        print(f"No valid {metric_name} values found.")
        return

    print(np.min(values), np.max(values))
    print(colorful.bold("Histogram:"))
    print(np.histogram(values))
    mean = np.mean(values)
    std = np.std(values)

    tabulate_row = [num_data, f"{mean:.06f} (±{std:.06f})"]

    print(colorful.bold("Result:"))
    print(
        tabulate([tabulate_row],
                 headers=["# of data", metric_name],
                 tablefmt="github"))


@client.command(name="summarize_by_category")
@click.argument("--result-directory", type=str, required=True)
def summarize_by_category(args):
    """Print the mean Chamfer distance of each category.

    Model ids are expected in the form ``<category_id>_<object_id>``. For
    every category found (in sorted id order) the category name and the
    number of non-null values are printed, followed by a one-row table with
    the per-category means.
    """
    result, _, _ = _summarize(args.result_directory)
    metric_name = "chamfer_distance"

    map_category_name = {
        "02691156": "Plane",
        "04256520": "Sofa",
        "04379243": "Table",
        "04530566": "Vessel",
        "02958343": "Car",
        "03001627": "Chair",
        "02933112": "Cabinet",
        "03636649": "Lamp",
        "02818832": "Bed",
        "02828884": "Bench",
        "02871439": "Bookshelf",
        "02924116": "Bus",
        "03467517": "Guitar",
        "03790512": "Motorbike",
        "03948459": "Pistol",
        "04225987": "Skateboard",
    }

    rows = []
    for model_id, value in zip(result["model_id"], result[metric_name]):
        # Split on the first underscore only so that an object id which
        # itself contains underscores is kept intact.
        category_id, _, object_id = model_id.partition("_")
        rows.append({
            "category_id": category_id,
            "object_id": object_id,
            metric_name: value,
        })
    # Naming the columns keeps them present even when there are no rows.
    df = pd.DataFrame(rows,
                      columns=["category_id", "object_id", metric_name])

    print(colorful.bold("Result:"))
    tabulate_header = []
    tabulate_row = []
    for category_id in sorted(set(df["category_id"])):
        category_values = df.loc[df["category_id"] == category_id,
                                 metric_name]
        # Categories missing from the table are shown under their raw id
        # instead of aborting the whole report with a KeyError.
        category_name = map_category_name.get(category_id, category_id)
        print(category_name, category_values.count())
        tabulate_header.append(category_name)
        tabulate_row.append(f"{category_values.mean():.06f}")
    print(tabulate([tabulate_row], headers=tabulate_header, tablefmt="github"))


@client.command(name="chamfer_distance_data")
@click.argument("--checkpoint-directory", type=str, required=True)
@click.argument("--checkpoint-epoch", type=int, default=None)
@click.argument("--output-directory", type=str, required=True)
@click.argument("--partial-pcd-path", type=str, required=True)
@click.argument("--complete-pcd-path", type=str, required=True)
@click.argument("--chamfer-distance-method",
                type=click.Choice(["symmetric", "non_squared_symmetric"]),
                required=True)
def chamfer_distance_data(args):
    """Evaluate one partial/complete point cloud pair and store the metric.

    Loads the checkpoint (``model.pt`` or ``model.<epoch>.pt``) from the
    checkpoint directory, feeds the partial cloud through the model and
    compares the dense prediction with the complete cloud using the selected
    Chamfer distance. The result is written to
    ``<output-directory>/metrics/<category_id>_<model_id>.json``, where the
    category id is the parent directory name of the complete cloud and the
    model id is its file name without the ``.pcd`` suffix. The checkpoint's
    ``args.json`` is copied into the output directory if not already there.

    Raises:
        FileNotFoundError: If the checkpoint files or either point cloud
            file are missing.
        NotImplementedError: If the Chamfer distance method is unknown.
    """
    device = torch.device("cuda", 0)
    partial_pcd_path = Path(args.partial_pcd_path)
    complete_pcd_path = Path(args.complete_pcd_path)
    output_directory = Path(args.output_directory)
    metrics_directory = output_directory / "metrics"

    checkpoint_directory = Path(args.checkpoint_directory)
    args_path = checkpoint_directory / "args.json"
    if args.checkpoint_epoch is None:
        model_path = checkpoint_directory / "model.pt"
    else:
        model_path = checkpoint_directory / f"model.{args.checkpoint_epoch}.pt"

    # Validate all inputs up front, before any directory is created or the
    # model is loaded. Explicit raises (not ``assert``) survive ``-O``.
    for required_path in (args_path, model_path, partial_pcd_path,
                          complete_pcd_path):
        if not required_path.is_file():
            raise FileNotFoundError(
                f"Required file not found: {required_path}")

    chamfer_distance_functions = {
        "symmetric": _compute_chamfer_distance,
        "non_squared_symmetric": _compute_non_squared_chamfer_distance,
    }
    if args.chamfer_distance_method not in chamfer_distance_functions:
        raise NotImplementedError()
    compute_chamfer_distance = chamfer_distance_functions[
        args.chamfer_distance_method]

    mkdir(output_directory)
    mkdir(metrics_directory)

    model_hyperparams = ModelHyperparameters.load_json(args_path)

    model = setup_model(model_hyperparams)
    load_model(model_path, model)
    model.to(device)
    model.eval()

    dest_path = output_directory / "args.json"
    if not dest_path.exists():
        shutil.copyfile(args_path, dest_path)

    complete_point_cloud = _read_pcd(complete_pcd_path)

    partial_point_cloud = _read_pcd(partial_pcd_path)
    partial_point_cloud = torch.from_numpy(partial_point_cloud).to(
        device).type(torch.float32)
    # Add a leading batch dimension of size one.
    partial_point_cloud = partial_point_cloud[None, :, :]

    with torch.no_grad():
        _, pred_dense_points = model(partial_point_cloud)
    pred_dense_points = pred_dense_points.cpu().numpy()[0]

    print(complete_point_cloud.shape, partial_point_cloud.shape)

    chamfer_distance = compute_chamfer_distance(
        gt_surface_points=complete_point_cloud,
        pred_surface_points=pred_dense_points)

    category_id = complete_pcd_path.parts[-2]
    # Strip only a trailing ".pcd"; occurrences elsewhere in the file name
    # belong to the model id.
    pcd_suffix = ".pcd"
    model_id = complete_pcd_path.name
    if model_id.endswith(pcd_suffix):
        model_id = model_id[:-len(pcd_suffix)]
    print(category_id, model_id, chamfer_distance, flush=True)

    # NOTE: for visual debugging, the complete, partial and predicted clouds
    # can be wrapped in o3d.geometry.PointCloud objects and shown with
    # o3d.visualization.draw_geometries.

    result = {
        "model_id": f"{category_id}_{model_id}",
        # Store a plain Python float so serialisation does not depend on the
        # numpy scalar type returned by the metric.
        "chamfer_distance": float(chamfer_distance),
        "chamfer_distance_num_samples": len(complete_point_cloud),
        "chamfer_distance_method": args.chamfer_distance_method,
    }
    with open(metrics_directory / f"{category_id}_{model_id}.json", "w") as f:
        json.dump(result, f, indent=4, sort_keys=True)


@client.command(name="chamfer_distance_on_training_dataset")
@click.argument("--checkpoint-directory", type=str, required=True)
@click.argument("--checkpoint-epoch", type=int, default=None)
@click.argument("--output-directory", type=str, required=True)
@click.argument("--dataset-directory", type=str, required=True)
@click.argument("--split-path", type=str, required=True)
@click.argument("--chamfer-distance-method",
                type=click.Choice(["symmetric", "non_squared_symmetric"]),
                required=True)
def chamfer_distance_on_training_dataset(args):
    """Print Chamfer distances for every sample listed in a split file.

    The split file is a JSON object mapping a category to a list of model
    ids; each entry is resolved to
    ``<dataset-directory>/<category>/<model_id>.npz``. For every sample the
    model's dense prediction is compared with the ground-truth dense points
    twice: with the selected KD-tree based Chamfer distance and with the
    ``torch_chamfer_distance`` based dense loss. Both numbers are printed on
    one line. Nothing is written to ``--output-directory`` by this command.

    Raises:
        FileNotFoundError: If the checkpoint files or the split file are
            missing.
        NotImplementedError: If the Chamfer distance method is unknown.
    """
    device = torch.device("cuda", 0)

    checkpoint_directory = Path(args.checkpoint_directory)
    args_path = checkpoint_directory / "args.json"
    if args.checkpoint_epoch is None:
        model_path = checkpoint_directory / "model.pt"
    else:
        model_path = checkpoint_directory / f"model.{args.checkpoint_epoch}.pt"
    dataset_directory = Path(args.dataset_directory)
    split_path = Path(args.split_path)

    # Validate every input before the model is loaded. Explicit raises (not
    # ``assert``) keep the checks active under ``-O``.
    for required_path in (args_path, model_path, split_path):
        if not required_path.is_file():
            raise FileNotFoundError(
                f"Required file not found: {required_path}")

    chamfer_distance_functions = {
        "symmetric": _compute_chamfer_distance,
        "non_squared_symmetric": _compute_non_squared_chamfer_distance,
    }
    if args.chamfer_distance_method not in chamfer_distance_functions:
        raise NotImplementedError()
    compute_chamfer_distance = chamfer_distance_functions[
        args.chamfer_distance_method]

    model_hyperparams = ModelHyperparameters.load_json(args_path)

    model = setup_model(model_hyperparams)
    load_model(model_path, model)
    model.to(device)
    model.eval()

    with open(split_path) as f:
        split = json.load(f)

    data_path_list = [
        dataset_directory / category / f"{model_id}.npz"
        for category, model_id_list in split.items()
        for model_id in model_id_list
    ]

    dataset = Dataset(data_path_list,
                      num_coarse_points=model_hyperparams.num_coarse_gt_points,
                      num_dense_points=model_hyperparams.num_dense_gt_points)

    minibatch_generator = MinibatchGenerator(num_input_points=3000,
                                             device=device)

    # Imported lazily: only this command needs the package.
    from torch_chamfer_distance import (ChamferDistance as
                                        ChamferDistanceFunction)
    compute_chamfer_distance_loss = ChamferDistanceFunction()
    for data in dataset:
        # Each sample is evaluated on its own as a batch of size one.
        batch = minibatch_generator([data])
        with torch.no_grad():
            _, pred_dense_points = model(batch.input_points)

        dist_1, dist_2 = compute_chamfer_distance_loss(batch.gt_dense_points,
                                                       pred_dense_points)
        loss_dense = torch.sqrt(dist_1).mean() + torch.sqrt(dist_2).mean()

        pred_dense_points = pred_dense_points.cpu().numpy()[0]
        complete_point_cloud = batch.gt_dense_points.cpu().numpy()[0]

        chamfer_distance = compute_chamfer_distance(
            gt_surface_points=complete_point_cloud,
            pred_surface_points=pred_dense_points)

        print(chamfer_distance, loss_dense.item())


# Dispatch to the command group when this file is executed as a script.
if __name__ == "__main__":
    client()
