import json
import math
import shutil
from collections import defaultdict
from pathlib import Path

import colorful
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import trimesh
from scipy.spatial import cKDTree as KDTree
from tabulate import tabulate

from pcn import click, load_model, mkdir
from pcn import point_cloud as pcl
from pcn.experiment import (Model, ModelHyperparameters,
                            TrainingHyperparameters, setup_model)
import open3d as o3d


def _compute_chamfer_distance(gt_surface_points: np.ndarray,
                              pred_surface_points: np.ndarray):
    gt_points_kd_tree = KDTree(gt_surface_points)
    distances, locations = gt_points_kd_tree.query(pred_surface_points)
    cd_term1 = np.mean(np.square(distances))

    pred_points_kd_tree = KDTree(pred_surface_points)
    distances, locations = pred_points_kd_tree.query(gt_surface_points)
    cd_term2 = np.mean(np.square(distances))

    return cd_term1 + cd_term2


def _compute_non_squared_chamfer_distance(gt_surface_points: np.ndarray,
                                          pred_surface_points: np.ndarray):
    gt_points_kd_tree = KDTree(gt_surface_points)
    distances, locations = gt_points_kd_tree.query(pred_surface_points)
    cd_term1 = np.mean(distances)

    pred_points_kd_tree = KDTree(pred_surface_points)
    distances, locations = pred_points_kd_tree.query(gt_surface_points)
    cd_term2 = np.mean(distances)

    return cd_term1 + cd_term2


@click.group()
def client():
    # Root command group for this script. The sub-commands in this module
    # register themselves with ``@client.command``. There is no shared setup,
    # so the body is intentionally empty.
    pass


@client.command(name="evaluate_data")
@click.argument("--checkpoint-directory", type=str, required=True)
@click.argument("--checkpoint-epoch", type=int, default=None)
@click.argument("--output-directory", type=str, required=True)
@click.argument("--bbox-path", type=str, required=True)
@click.argument("--pcd-path", type=str, required=True)
def evaluate_data(args):
    """Complete a single partial point cloud with a trained model.

    Loads the checkpoint from ``args.checkpoint_directory``. This is
    ``model.pt``, or ``model.<epoch>.pt`` when ``args.checkpoint_epoch`` is
    given. The partial cloud from ``args.pcd_path`` is normalized into the
    frame defined by the box in ``args.bbox_path`` and run through the
    model.

    The dense prediction is written to
    ``<output-directory>/output/<data_id>.npz`` with keys ``pred_points``
    and ``scale``. ``scale`` is ``1 / max_distance``, the final
    normalization factor only; the centering, rotation, bbox scaling and
    axis remap are not stored. The checkpoint's ``args.json`` is copied to
    ``<output-directory>`` if it is not already there.

    Args:
        args: Parsed command-line arguments declared by the decorators.

    Raises:
        FileNotFoundError: If ``args.json`` or the model file is missing.
        ValueError: If no points could be read from ``args.pcd_path``.
    """
    bbox_path = Path(args.bbox_path)
    pcd_path = Path(args.pcd_path)
    output_root_directory = Path(args.output_directory)
    output_directory = output_root_directory / "output"

    checkpoint_directory = Path(args.checkpoint_directory)
    args_path = checkpoint_directory / "args.json"
    if args.checkpoint_epoch is None:
        model_path = checkpoint_directory / "model.pt"
    else:
        model_path = checkpoint_directory / f"model.{args.checkpoint_epoch}.pt"
    # Explicit checks rather than ``assert`` so they still run under -O.
    # They happen before any output directory is created.
    if not args_path.is_file():
        raise FileNotFoundError(f"hyperparameter file not found: {args_path}")
    if not model_path.is_file():
        raise FileNotFoundError(f"model checkpoint not found: {model_path}")

    mkdir(output_root_directory)
    mkdir(output_directory)

    # Prefer the first GPU, but do not fail immediately on CPU-only hosts.
    # NOTE(review): the CPU path only works if ``load_model`` can map a
    # GPU-saved checkpoint onto the CPU — confirm.
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)
    else:
        device = torch.device("cpu")

    model_hyperparams = ModelHyperparameters.load_json(args_path)
    model = setup_model(model_hyperparams)
    load_model(model_path, model)
    model.to(device)
    model.eval()

    dest_path = output_root_directory / "args.json"
    if not dest_path.exists():
        shutil.copyfile(args_path, dest_path)

    # Build the canonical frame from the box. Center it, then rotate by
    # -yaw so the edge from row 0 to row 3 lies along +x. That edge's xy
    # length becomes the scale.
    # NOTE(review): assumes rows 0 and 3 are adjacent base corners of the
    # box — confirm against the bbox file format.
    bbox = np.loadtxt(bbox_path)
    center = (bbox.min(0) + bbox.max(0)) / 2
    bbox -= center
    yaw = np.arctan2(bbox[3, 1] - bbox[0, 1], bbox[3, 0] - bbox[0, 0])
    rotation = np.array([[np.cos(yaw), -np.sin(yaw), 0],
                         [np.sin(yaw), np.cos(yaw), 0], [0, 0, 1]])
    bbox = np.dot(bbox, rotation)
    scale = bbox[3, 0] - bbox[0, 0]

    partial_pcd = o3d.io.read_point_cloud(str(pcd_path))
    partial_points = np.array(partial_pcd.points)
    # Open3D typically returns an empty cloud (with only a warning) for an
    # unreadable file. Report that clearly instead of failing later with a
    # zero-size reduction error.
    if partial_points.size == 0:
        raise ValueError(f"no points could be read from {pcd_path}")
    partial_points = np.dot(partial_points - center, rotation) / scale
    # Two axis remaps; their net effect is (x, y, z) -> (y, z, -x).
    # Presumably this matches the model's training frame — TODO confirm.
    partial_points = np.dot(partial_points, [[1, 0, 0], [0, 0, 1], [0, 1, 0]])
    partial_points = np.dot(partial_points, [[0, 0, -1], [0, 1, 0], [1, 0, 0]])
    # Shrink so every point lies strictly inside the unit sphere.
    max_distance = np.linalg.norm(partial_points, axis=1).max() * 1.03
    partial_points /= max_distance

    # Cast on the host, then add a batch dimension: (N, 3) -> (1, N, 3).
    input_points = torch.from_numpy(
        partial_points.astype(np.float32))[None].to(device)
    with torch.no_grad():
        # The model returns (coarse, dense); only the dense output is kept.
        _, pred_dense_points = model(input_points)
    pred_dense_points = pred_dense_points.cpu().numpy()[0]

    data_id = bbox_path.name.replace(".txt", "")
    output_path = output_directory / f"{data_id}.npz"
    np.savez(output_path,
             pred_points=pred_dense_points,
             scale=1 / max_distance)
    # Report completion only once the result is actually on disk.
    print(data_id, "done", flush=True)


@client.command()
@click.argument("--result-directory", type=str, required=True)
def summarize(args):
    """Print the id and stored array names of every ``*.npz`` result file.

    Args:
        args: Parsed command-line arguments. ``args.result_directory`` is
            the directory scanned (non-recursively) for ``*.npz`` files.
    """
    result_directory = Path(args.result_directory)
    # Sort so the listing is reproducible; raw glob order depends on the
    # filesystem.
    for npz_file in sorted(result_directory.glob("*.npz")):
        print(npz_file.stem)
        # np.load on an .npz returns an NpzFile that holds the archive open.
        # The context manager closes it instead of leaking one handle per
        # file.
        with np.load(npz_file) as result:
            print(result.files)


@client.command(name="check_data")
@click.argument("--bbox-path", type=str, required=True)
@click.argument("--pcd-path", type=str, required=True)
@click.argument("--npz-path", type=str, required=True)
def check_data(args):
    """Show a normalized partial cloud next to a training sample.

    The cloud in ``args.pcd_path`` is normalized as follows:
    - centered on the box from ``args.bbox_path``;
    - rotated by -yaw of the box edge from row 0 to row 3;
    - divided by that edge's length;
    - axis-remapped via (x, y, z) -> (y, z, -x);
    - scaled to lie strictly inside the unit sphere.

    The cloud is then offset by +1 along y and shown in an Open3D viewer
    together with the ``vertices`` array from ``args.npz_path``. The
    normalized box corners are printed first.

    Args:
        args: Parsed command-line arguments declared by the decorators.

    Raises:
        ValueError: If no points could be read from ``args.pcd_path``.
    """
    bbox = np.loadtxt(args.bbox_path)
    center = (bbox.min(0) + bbox.max(0)) / 2
    bbox -= center
    yaw = np.arctan2(bbox[3, 1] - bbox[0, 1], bbox[3, 0] - bbox[0, 0])
    rotation = np.array([[np.cos(yaw), -np.sin(yaw), 0],
                         [np.sin(yaw), np.cos(yaw), 0], [0, 0, 1]])
    bbox = np.dot(bbox, rotation)
    scale = bbox[3, 0] - bbox[0, 0]
    bbox /= scale
    print(bbox)

    pcd = o3d.io.read_point_cloud(args.pcd_path)
    partial_points = np.array(pcd.points)
    # Open3D typically returns an empty cloud for an unreadable file.
    # Report that clearly instead of failing in the norm reduction below.
    if partial_points.size == 0:
        raise ValueError(f"no points could be read from {args.pcd_path}")
    partial_points = np.dot(partial_points - center, rotation) / scale
    partial_points = np.dot(partial_points, [[1, 0, 0], [0, 0, 1], [0, 1, 0]])
    partial_points = np.dot(partial_points, [[0, 0, -1], [0, 1, 0], [1, 0, 0]])
    max_distance = np.linalg.norm(partial_points, axis=1).max() * 1.03
    partial_points /= max_distance
    # Offset along y, presumably to keep the two clouds apart in the viewer.
    partial_points[:, 1] += 1
    pcd.points = o3d.utility.Vector3dVector(partial_points)

    # Close the archive once the array is read. Indexing materializes the
    # array, so ``vertices`` stays valid after the file is closed.
    with np.load(args.npz_path) as training_data:
        vertices = training_data["vertices"]
    training_pcd = o3d.geometry.PointCloud()
    training_pcd.points = o3d.utility.Vector3dVector(vertices)

    o3d.visualization.draw_geometries([pcd, training_pcd])


if __name__ == "__main__":
    # Script entry point: hand control to the ``client`` command group,
    # which dispatches to the sub-command named on the command line.
    client()
