import json
import math
import shutil
from collections import defaultdict
from pathlib import Path
from hashlib import sha256

import colorful
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import trimesh
from scipy.spatial import cKDTree as KDTree
from tabulate import tabulate

from deep_sdf import click, load_model, mkdir
from deep_sdf import point_cloud as pcl
from deep_sdf.datasets.shapenet import Minibatch
from deep_sdf.datasets.shapenet.partial_sampling import (
    MeshData, MinibatchGenerator, NpzDataset, PartialPointCloudData,
    PointCloudAndMeshPairDataset)
from deep_sdf.experiments.learning_shape_space import (
    AdaptiveLatentOptimization, DecoderHyperparameters, LatentOptimization,
    LossFunction, MarchingCubes, Model, ModelHyperparameters,
    TrainingHyperparameters, setup_model, LatentOptimizationHyperparameters)


def _normalize(vec: np.ndarray):
    return vec / np.linalg.norm(vec)


def _look_at(eye: np.ndarray, center: np.ndarray, up: np.ndarray):
    """Build a look-at camera frame.

    Returns:
        ``(rotation_matrix, translation_vector)``:
        ``rotation_matrix`` is a float32 3x3 matrix whose columns are the
        camera x, y and z axes expressed in world coordinates. The z axis
        points from ``center`` towards ``eye``.
        ``translation_vector`` is ``[-x·eye, -y·eye, -z·eye]``.

    NOTE(review): this pairing looks intended for row-vector transforms
    (``p @ R + t``); confirm against callers.
    """
    eye, center, up = (np.asanyarray(v) for v in (eye, center, up))

    z_axis = _normalize(eye - center)
    # Derive the second axis from the *unnormalized* first one, then
    # normalize both; the direction is unaffected by that scale.
    x_raw = np.cross(up, z_axis)
    y_raw = np.cross(z_axis, x_raw)
    x_axis = _normalize(x_raw)
    y_axis = _normalize(y_raw)

    rotation_matrix = np.stack([x_axis, y_axis, z_axis], axis=1).astype(np.float32)
    translation_vector = np.array([-x_axis @ eye, -y_axis @ eye, -z_axis @ eye])
    return rotation_matrix, translation_vector


def _sample_surface_points(faces: np.ndarray, vertices: np.ndarray,
                           num_point_samples: int):
    """Sample ``num_point_samples`` points on the surface of a triangle mesh.

    The mesh is built from ``vertices`` and ``faces`` with ``trimesh``, and
    the points are drawn with ``trimesh.sample.sample_surface``. Only the
    sampled coordinates are returned; the per-sample face indices are discarded.
    """
    surface_mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
    sampled_points, _ = trimesh.sample.sample_surface(surface_mesh,
                                                      num_point_samples)
    return sampled_points


def _compute_chamfer_distance(gt_surface_points: np.ndarray,
                              pred_surface_points: np.ndarray):
    gt_points_kd_tree = KDTree(gt_surface_points)
    distances, locations = gt_points_kd_tree.query(pred_surface_points)
    cd_term1 = np.mean(np.square(distances))

    pred_points_kd_tree = KDTree(pred_surface_points)
    distances, locations = pred_points_kd_tree.query(gt_surface_points)
    cd_term2 = np.mean(np.square(distances))

    return cd_term1 + cd_term2


def _summarize(result_directory: str):
    result_directory = Path(result_directory)
    metrics_directory = result_directory / "metrics"
    args_path = result_directory / "args.json"
    assert args_path.is_file()
    with open(args_path) as f:
        args = json.load(f)
    table = defaultdict(list)
    metrics_path_list = list(metrics_directory.glob("*.json"))
    for metrics_path in metrics_path_list:
        with open(metrics_path) as f:
            metrics = json.load(f)
        for key, value in metrics.items():
            table[key].append(value)

    return table, args, len(metrics_path_list)


class ChamferDistance:
    """Per-viewpoint Chamfer distance between a ground-truth mesh and
    reconstructions from partial point clouds.

    For every viewpoint of a ``PartialPointCloudData`` sample, the evaluator:

    1. optimizes a latent code against that view's points,
    2. extracts a mesh with marching cubes on the ``[-1, 1]`` grid,
    3. maps the mesh back with ``/ scale - offset``, and
    4. compares surface samples of the reconstruction and of the
       ground-truth mesh.
    """

    def __init__(self, loss_function: LossFunction,
                 minibatch_generator: MinibatchGenerator,
                 optimize_latent_func: LatentOptimization, model: Model,
                 grid_size: int, latent_optimization_num_samples: int,
                 chamfer_distance_num_samples: int):
        # NOTE(review): ``loss_function`` is stored but not used by this
        # class; it is kept for interface compatibility.
        self.loss_function = loss_function
        self.minibatch_generator = minibatch_generator
        self.optimize_latent = optimize_latent_func
        self.model = model
        self.grid_size = grid_size
        self.latent_optimization_num_samples = latent_optimization_num_samples
        self.chamfer_distance_num_samples = chamfer_distance_num_samples

    def __call__(self, pc_data: PartialPointCloudData, mesh_data: MeshData):
        """Return one ``{"view_index", "chamfer_distance"}`` dict per viewpoint.

        ``chamfer_distance`` is ``None`` for a view when marching cubes or
        surface sampling raised ``ValueError``.
        """
        gt_faces = mesh_data.vertex_indices
        gt_vertices = mesh_data.vertices

        batch = self.minibatch_generator(
            [pc_data],
            num_point_samples=self.latent_optimization_num_samples,
            num_viewpoint_samples=pc_data.num_viewpoints)
        points_list = batch.points_list
        # ``Tensor.device`` works for CPU and CUDA tensors alike. By contrast,
        # ``Tensor.get_device()`` returns -1 for CPU tensors, which ``.to()`` rejects.
        device = points_list[0].device

        # The extractor depends only on loop-invariant settings, so build it once.
        marching_cubes = MarchingCubes(model=self.model,
                                       grid_size=self.grid_size,
                                       grid_max_value=1,
                                       grid_min_value=-1)

        chamfer_distance_for_each_viewpoints = []
        for view_index in range(pc_data.num_viewpoints):
            points = points_list[view_index]
            # Every observed point gets target distance 0 (presumably treated
            # as on-surface samples); ``.to(points)`` matches dtype and device.
            distances = torch.zeros(points.shape[:2]).to(points)
            input_data = Minibatch(points=points,
                                   distances=distances,
                                   data_indices=None)
            initial_z = torch.normal(mean=0,
                                     std=0.01,
                                     size=(1, self.model.z_map.shape[1]),
                                     dtype=torch.float32).to(device)
            z = self.optimize_latent(input_data, initial_z)

            try:
                mc_vertices, mc_faces = marching_cubes(z)
                mc_vertices = mc_vertices / pc_data.scale - pc_data.offset
                gt_surface_points = _sample_surface_points(
                    gt_faces, gt_vertices, self.chamfer_distance_num_samples)
                pred_surface_points = _sample_surface_points(
                    mc_faces, mc_vertices, self.chamfer_distance_num_samples)

                chamfer_distance = _compute_chamfer_distance(
                    gt_surface_points, pred_surface_points)
            except ValueError:
                # A failed reconstruction is recorded as ``None``, not raised.
                chamfer_distance = None

            chamfer_distance_for_each_viewpoints.append({
                "view_index": view_index,
                "chamfer_distance": chamfer_distance
            })
            print("view:",
                  view_index,
                  "chamfer_distance:",
                  chamfer_distance,
                  flush=True)

        return chamfer_distance_for_each_viewpoints


class SurfaceChamferDistance:
    """Chamfer distance between a partial point cloud's vertices and
    reconstructions from a random subset of its viewpoints.

    The reference points are ``pc_data.vertices / scale - offset``; no mesh
    is involved.

    Seeding behaviour:

    * ``seed == -1``: minibatch sampling is unseeded, and view selection is
      seeded from the sample's ``<category>/<model>`` path components.
    * Otherwise: NumPy and torch are both seeded with ``seed``, and torch is
      reseeded before view selection.
    """

    def __init__(self,
                 minibatch_generator: MinibatchGenerator,
                 optimize_latent_func: LatentOptimization,
                 model: Model,
                 grid_size: int,
                 latent_optimization_num_samples: int,
                 chamfer_distance_num_samples: int,
                 num_viewpoint_samples: int,
                 seed: int = 0):
        self.optimize_latent = optimize_latent_func
        self.model = model
        self.minibatch_generator = minibatch_generator
        self.grid_size = grid_size
        self.latent_optimization_num_samples = latent_optimization_num_samples
        self.chamfer_distance_num_samples = chamfer_distance_num_samples
        self.num_viewpoint_samples = num_viewpoint_samples
        self.seed = seed

    def _select_view_indices(self, pc_data: PartialPointCloudData,
                             num_views: int):
        """Draw distinct view indices: ``min(num_viewpoint_samples, num_views)`` of them."""
        if self.seed == -1:
            # Per-shape seed, so the chosen views are reproducible without a
            # global seed.
            seed_key = pc_data.path.parts[-3] + "/" + pc_data.path.parts[-2]
            print(seed_key)
            digest = sha256(seed_key.encode()).digest()
            random_state = np.random.RandomState(
                np.frombuffer(digest, dtype=np.uint32))
        else:
            torch.manual_seed(self.seed)
            random_state = np.random.RandomState(self.seed)
        # ``choice(replace=False)`` raises ValueError if asked for more
        # items than exist, so clamp to the number of available views.
        num_samples = min(self.num_viewpoint_samples, num_views)
        view_indices = random_state.choice(num_views,
                                           size=num_samples,
                                           replace=False)
        print(view_indices)
        return view_indices

    def __call__(self, pc_data: PartialPointCloudData):
        """Return one ``{"view_index", "chamfer_distance"}`` dict per sampled view.

        ``chamfer_distance`` is ``None`` for a view when marching cubes or
        surface sampling raised ``ValueError``.
        """
        device = self.model.get_device()
        if self.seed == -1:
            random_state = np.random.RandomState(None)
        else:
            torch.manual_seed(self.seed)
            random_state = np.random.RandomState(self.seed)
        batch = self.minibatch_generator(
            [pc_data],
            num_point_samples=self.latent_optimization_num_samples,
            num_viewpoint_samples=pc_data.num_viewpoints,
            random_state=random_state)
        points_list = batch.points_list
        gt_surface_points = pc_data.vertices / pc_data.scale - pc_data.offset

        view_indices = self._select_view_indices(pc_data, len(points_list))

        # The extractor depends only on loop-invariant settings, so build it once.
        marching_cubes = MarchingCubes(model=self.model,
                                       grid_size=self.grid_size,
                                       grid_max_value=1,
                                       grid_min_value=-1)

        chamfer_distance_for_each_viewpoints = []
        for view_index in view_indices:
            points = points_list[view_index]
            # Observed points get target distance 0 (presumably treated as
            # on-surface samples).
            distances = torch.zeros(points.shape[:2]).to(device)
            input_data = Minibatch(points=points,
                                   distances=distances,
                                   data_indices=None)
            initial_z = torch.normal(mean=0,
                                     std=0.01,
                                     size=(1, self.model.z_map.shape[1]),
                                     dtype=torch.float32).to(device)
            z = self.optimize_latent(input_data, initial_z)

            try:
                mc_vertices, mc_faces = marching_cubes(z)
                mc_vertices = mc_vertices / pc_data.scale - pc_data.offset
                pred_surface_points = _sample_surface_points(
                    mc_faces, mc_vertices, self.chamfer_distance_num_samples)
                chamfer_distance = _compute_chamfer_distance(
                    gt_surface_points, pred_surface_points)
            except ValueError:
                # A failed reconstruction is recorded as ``None``, not raised.
                chamfer_distance = None

            chamfer_distance_for_each_viewpoints.append({
                "view_index": int(view_index),
                "chamfer_distance": chamfer_distance
            })
            print("view:",
                  view_index,
                  "chamfer_distance:",
                  chamfer_distance,
                  flush=True)

        return chamfer_distance_for_each_viewpoints


@click.group()
def client():
    """Command group that the evaluation sub-commands below attach to."""


@client.command()
@click.argument("--result-directory", type=str, required=True)
def summarize(args):
    """Print aggregate Chamfer-distance statistics for a result directory.

    Reads ``<result-directory>/metrics/*.json`` through ``_summarize``.
    Views whose ``chamfer_distance`` is ``None`` are reported as failed and
    excluded. The command prints the number of measurements per view index,
    then a table with the mean and standard deviation over all remaining
    measurements.
    """
    result, _, num_data = _summarize(args.result_directory)
    if num_data == 0:
        print("no metrics found in", args.result_directory)
        return

    # Taken from the first metrics file; presumably identical across files.
    latent_optimization_num_samples = result[
        "latent_optimization_num_samples"][0]

    data_list = []
    for model_id, chamfer_distance_for_each_viewpoints in zip(
            result["model_id"],
            result["chamfer_distance_for_each_viewpoints"]):
        for row in chamfer_distance_for_each_viewpoints:
            view_index = row["view_index"]
            chamfer_distance = row["chamfer_distance"]
            if chamfer_distance is None:
                print(model_id, view_index, "failed")
                continue
            data_list.append({
                "model_id": model_id,
                "view_index": view_index,
                "chamfer_distance": chamfer_distance
            })

    if not data_list:
        print("no successful chamfer distance measurements")
        return

    df = pd.DataFrame(data_list)
    print(df.groupby("view_index").size())

    chamfer_distances = df["chamfer_distance"]
    mean = chamfer_distances.mean()
    std = chamfer_distances.std()
    print(
        tabulate(
            [[latent_optimization_num_samples, f"{mean:.06f} (±{std:.06f})"]],
            headers=["# of context", "chamfer_distance"],
            tablefmt="github"))


@client.command(name="chamfer_distance_data")
@click.argument("--checkpoint-directory", type=str, required=True)
@click.argument("--checkpoint-epoch", type=int, default=None)
@click.argument("--output-directory", type=str, required=True)
@click.argument("--npz-path", type=str, required=True)
@click.argument("--obj-path", type=str, required=True)
@click.argument("--grid-size", type=int, default=512)
@click.argument("--latent-optimization-iterations", type=int, default=800)
@click.argument("--latent-optimization-initial-lr", type=float, default=0.005)
@click.argument("--latent-optimization-num-samples", type=int, default=30000)
@click.argument("--chamfer-distance-num-samples", type=int, default=30000)
def chamfer_distance_data(args):
    """Evaluate per-viewpoint Chamfer distance against a ground-truth mesh.

    Loads the model from ``--checkpoint-directory``: ``model.pt``, or
    ``model.<epoch>.pt`` when ``--checkpoint-epoch`` is given. For the
    ``(--npz-path, --obj-path)`` pair, it writes one JSON file per model to
    ``<output-directory>/metrics``. The checkpoint's ``args.json`` is also
    copied into the output directory. The model runs on ``cuda:0``.

    Raises:
        FileNotFoundError: if the checkpoint's ``args.json`` or model file
            does not exist.
    """
    device = torch.device("cuda", 0)
    npz_path = Path(args.npz_path)
    obj_path = Path(args.obj_path)

    checkpoint_directory = Path(args.checkpoint_directory)
    args_path = checkpoint_directory / "args.json"
    if args.checkpoint_epoch is None:
        model_path = checkpoint_directory / "model.pt"
    else:
        model_path = checkpoint_directory / f"model.{args.checkpoint_epoch}.pt"
    # Explicit checks rather than ``assert`` so they survive ``python -O``.
    # They run before any output directory is created.
    if not args_path.is_file():
        raise FileNotFoundError(f"hyperparameter file not found: {args_path}")
    if not model_path.is_file():
        raise FileNotFoundError(f"model checkpoint not found: {model_path}")

    output_directory = Path(args.output_directory)
    metrics_directory = output_directory / "metrics"
    mkdir(output_directory)
    mkdir(metrics_directory)

    shutil.copyfile(args_path, output_directory / "args.json")

    model_hyperparams = ModelHyperparameters.load_json(args_path)
    decoder_hyperparams = DecoderHyperparameters.load_json(args_path)
    training_hyperparams = TrainingHyperparameters.load_json(args_path)

    model = setup_model(model_hyperparams, decoder_hyperparams)
    load_model(model_path, model)
    model.to(device)
    model.eval()

    npz_obj_path_list = [(npz_path, obj_path)]
    dataset = PointCloudAndMeshPairDataset(npz_obj_path_list)
    minibatch_generator = MinibatchGenerator(device=device)

    loss_function = LossFunction(
        lam=training_hyperparams.loss_lam,
        clamping_distance=training_hyperparams.clamping_distance)

    optimize_latent = LatentOptimization(
        model=model,
        loss_function=loss_function,
        lr=args.latent_optimization_initial_lr,
        decrease_lr_every=args.latent_optimization_iterations // 2,
        max_iterations=args.latent_optimization_iterations)
    compute_chamfer_distance = ChamferDistance(
        loss_function=loss_function,
        minibatch_generator=minibatch_generator,
        optimize_latent_func=optimize_latent,
        model=model,
        grid_size=args.grid_size,
        latent_optimization_num_samples=args.latent_optimization_num_samples,
        chamfer_distance_num_samples=args.chamfer_distance_num_samples)

    for raw_data_pair in dataset:
        pc_data: PartialPointCloudData = raw_data_pair[0]
        mesh_data: MeshData = raw_data_pair[1]

        # ``Path.parts`` is separator-agnostic, unlike splitting on "/".
        parts = Path(pc_data.path).parts
        category_id = parts[-3]
        model_id = parts[-2]

        result_path = metrics_directory / f"{category_id}_{model_id}.json"
        if result_path.exists():
            with open(result_path, encoding="utf-8") as f:
                metrics = json.load(f)
            # NOTE(review): nothing in this file writes -1 here, so in
            # practice any existing metrics file is treated as done.
            if metrics["chamfer_distance_for_each_viewpoints"] != -1:
                print(category_id, model_id, "skipped", flush=True)
                continue

        chamfer_distance_for_each_viewpoints = compute_chamfer_distance(
            pc_data, mesh_data)
        print(category_id, model_id, flush=True)

        result = {
            "model_id": f"{category_id}_{model_id}",
            "chamfer_distance_for_each_viewpoints":
            chamfer_distance_for_each_viewpoints,
            "chamfer_distance_num_samples": args.chamfer_distance_num_samples,
            "latent_optimization_num_samples":
            args.latent_optimization_num_samples,
            "grid_size": args.grid_size,
            "latent_optimization_iterations":
            args.latent_optimization_iterations,
            "latent_optimization_initial_lr":
            args.latent_optimization_initial_lr
        }
        with open(result_path, "w", encoding="utf-8") as f:
            json.dump(result, f, indent=4, sort_keys=True)


@client.command(name="chamfer_distance_surface_data")
@click.argument("--checkpoint-directory", type=str, required=True)
@click.argument("--checkpoint-epoch", type=int, default=None)
@click.argument("--output-directory", type=str, required=True)
@click.argument("--npz-path", type=str, required=True)
@click.argument("--grid-size", type=int, default=256)
@click.argument("--num-viewpoint-samples", type=int, default=5)
@click.argument("--seed", type=int, default=0)
@click.argument("--chamfer-distance-num-samples", type=int, default=30000)
@click.hyperparameter_class(LatentOptimizationHyperparameters)
def chamfer_distance_surface_data(
        args, latent_optimization_params: LatentOptimizationHyperparameters):
    """Evaluate Chamfer distance of reconstructions against the point cloud itself.

    Loads the model from ``--checkpoint-directory``: ``model.pt``, or
    ``model.<epoch>.pt`` when ``--checkpoint-epoch`` is given. Each sample
    in ``--npz-path`` is run through ``SurfaceChamferDistance`` with
    adaptive latent optimization. One JSON file per model is written to
    ``<output-directory>/metrics``, overwriting any existing one. The model
    runs on ``cuda:0``.

    Raises:
        FileNotFoundError: if the checkpoint's ``args.json`` or model file
            does not exist.
    """
    device = torch.device("cuda", 0)
    npz_path = Path(args.npz_path)

    checkpoint_directory = Path(args.checkpoint_directory)
    args_path = checkpoint_directory / "args.json"
    if args.checkpoint_epoch is None:
        model_path = checkpoint_directory / "model.pt"
    else:
        model_path = checkpoint_directory / f"model.{args.checkpoint_epoch}.pt"
    # Explicit checks rather than ``assert`` so they survive ``python -O``.
    # They run before any output directory is created.
    if not args_path.is_file():
        raise FileNotFoundError(f"hyperparameter file not found: {args_path}")
    if not model_path.is_file():
        raise FileNotFoundError(f"model checkpoint not found: {model_path}")

    output_directory = Path(args.output_directory)
    metrics_directory = output_directory / "metrics"
    mkdir(output_directory)
    mkdir(metrics_directory)

    shutil.copyfile(args_path, output_directory / "args.json")

    model_hyperparams = ModelHyperparameters.load_json(args_path)
    decoder_hyperparams = DecoderHyperparameters.load_json(args_path)
    training_hyperparams = TrainingHyperparameters.load_json(args_path)

    model = setup_model(model_hyperparams, decoder_hyperparams)
    load_model(model_path, model)
    model.to(device)
    model.eval()

    dataset = NpzDataset([npz_path])
    minibatch_generator = MinibatchGenerator(device=device)
    print(npz_path, flush=True)

    loss_function = LossFunction(
        lam=training_hyperparams.loss_lam,
        clamping_distance=training_hyperparams.clamping_distance)

    optimize_latent = AdaptiveLatentOptimization(
        model=model,
        loss_function=loss_function,
        params=latent_optimization_params)
    compute_chamfer_distance = SurfaceChamferDistance(
        optimize_latent_func=optimize_latent,
        minibatch_generator=minibatch_generator,
        num_viewpoint_samples=args.num_viewpoint_samples,
        model=model,
        grid_size=args.grid_size,
        latent_optimization_num_samples=latent_optimization_params.num_samples,
        chamfer_distance_num_samples=args.chamfer_distance_num_samples,
        seed=args.seed)

    for pc_data in dataset:
        chamfer_distance_for_each_viewpoints = compute_chamfer_distance(
            pc_data)

        # ``Path.parts`` is separator-agnostic, unlike splitting on "/".
        parts = Path(pc_data.path).parts
        category_id = parts[-3]
        model_id = parts[-2]
        result_path = metrics_directory / f"{category_id}_{model_id}.json"

        print(category_id, model_id, "done", flush=True)
        print(result_path)

        result = {
            "model_id":
            f"{category_id}_{model_id}",
            "chamfer_distance_for_each_viewpoints":
            chamfer_distance_for_each_viewpoints,
            "chamfer_distance_num_samples":
            args.chamfer_distance_num_samples,
            "grid_size":
            args.grid_size,
            "latent_optimization_num_samples":
            latent_optimization_params.num_samples,
            "latent_optimization_iterations":
            latent_optimization_params.iterations,
            "latent_optimization_initial_lr":
            latent_optimization_params.initial_lr,
            "latent_optimization_decrease_lr_every":
            latent_optimization_params.decrease_lr_every,
            "latent_optimization_lr_decay_factor":
            latent_optimization_params.lr_decay_factor,
        }
        with open(result_path, "w", encoding="utf-8") as f:
            json.dump(result, f, indent=4, sort_keys=True)


if __name__ == "__main__":
    # Running this file as a script dispatches to the click command group.
    client()
