import json
import random
import math
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, List

import colorful
import matplotlib.pyplot as plt
import numpy as np
import torch
import trimesh
import open3d as o3d

import meta_learning_sdf
from meta_learning_sdf import click, mkdir
from meta_learning_sdf import point_cloud as pcl
from meta_learning_sdf_comparison.args import (DeepSdfArgs, IgrArgs,
                                               ProposedMethodArgs, PcnArgs,
                                               OccNetArgs)
from meta_learning_sdf_comparison.datasets import (UniformPointCloudDataset,
                                                   DeepSdfDataset, PcnDataset,
                                                   OccNetDataset)
from meta_learning_sdf_comparison.geometry import get_extrinsics
import meta_learning_sdf_comparison


def _normalize(vec: np.ndarray):
    """Return ``vec`` divided by its norm as computed by ``np.linalg.norm``.

    No guard is applied for a zero-length input; the division is performed
    as-is.
    """
    length = np.linalg.norm(vec)
    return vec / length


def _look_at(eye: np.ndarray, center: np.ndarray, up: np.ndarray):
    """Build a camera frame located at ``eye`` and facing ``center``.

    Args:
        eye: Camera position (any array-like of three components).
        center: Point the camera looks at.
        up: Approximate up direction used to derive the side axis.

    Returns:
        A ``(rotation_matrix, translation_vector)`` tuple. The rotation is a
        3x3 ``float32`` matrix whose columns are the side, up and backward
        axes of the camera; the translation holds the negated dot product of
        each of those axes with ``eye``.
    """
    eye, center, up = (np.asanyarray(value) for value in (eye, center, up))

    # The backward axis points from the target toward the camera.
    backward = _normalize(eye - center)
    side = np.cross(up, backward)
    # The up axis is derived from the not-yet-normalized side axis; both are
    # normalized afterwards.
    true_up = np.cross(backward, side)
    side = _normalize(side)
    true_up = _normalize(true_up)

    camera_axes = (side, true_up, backward)
    rotation_matrix = np.stack(camera_axes, axis=1).astype(np.float32)
    translation_vector = np.array([(-axis) @ eye for axis in camera_axes])

    return rotation_matrix, translation_vector


@click.group()
def client():
    # Root command group of this script. Sub-commands register themselves on
    # it through ``@client.command(...)``; the body is intentionally empty.
    pass


def _load_shuffled_category_model_pairs(split_path: Path) -> list:
    """Read the split JSON and return shuffled ``(category_id, model_id)`` pairs.

    The JSON is expected to map each category id to a collection of model
    ids. Model ids are sorted per category before the global shuffle so the
    result depends only on the file contents and the state of ``random``.

    Raises:
        FileNotFoundError: If ``split_path`` is not an existing file.
    """
    if not split_path.is_file():
        raise FileNotFoundError(f"test split file not found: {split_path}")

    with open(split_path) as f:
        split = json.load(f)

    pairs = [(category_id, model_id) for category_id in split
             for model_id in sorted(split[category_id])]
    random.shuffle(pairs)
    return pairs


def _build_camera_extrinsics(theta: float, phi: float, radius: float):
    """Return ``(rotation_matrix, translation_vector)`` for a camera on a sphere.

    The camera sits at spherical coordinates ``(radius, theta, phi)`` with the
    y axis as the polar axis and looks at the origin. The rotation returned by
    ``_look_at`` is inverted and the translation is reshaped to ``(1, 3)`` so
    that points stored as rows can be mapped with ``(R @ P.T).T + t``.
    """
    eye = [
        radius * math.sin(theta) * math.cos(phi),
        radius * math.cos(theta),
        radius * math.sin(theta) * math.sin(phi),
    ]
    rotation_matrix, translation_vector = _look_at(eye=eye,
                                                   center=[0, 0, 0],
                                                   up=[0, 1, 0])
    return np.linalg.inv(rotation_matrix), translation_vector[None, :]


def _render_ground_truth_image(pc_data, rotation_matrix: np.ndarray,
                               translation_vector: np.ndarray):
    """Mesh the surface points of ``pc_data`` and render the mesh.

    ``pc_data`` must expose ``vertices`` and ``vertex_normals``. The points
    are meshed with Open3D ball pivoting, the mesh vertices are moved with the
    given rotation/translation and the result is rendered by
    ``pcl.render_mesh``.
    """
    point_cloud = o3d.geometry.PointCloud()
    point_cloud.points = o3d.utility.Vector3dVector(pc_data.vertices)
    point_cloud.normals = o3d.utility.Vector3dVector(pc_data.vertex_normals)

    # Ball radii are derived from the average nearest-neighbour spacing.
    distances = point_cloud.compute_nearest_neighbor_distance()
    radius = 1.5 * np.mean(distances)
    o3d_mesh = o3d.geometry.TriangleMesh.create_from_point_cloud_ball_pivoting(
        point_cloud, o3d.utility.DoubleVector([radius, radius * 2]))

    vertices = np.asarray(o3d_mesh.vertices)
    vertices = (rotation_matrix @ vertices.T).T + translation_vector

    # NOTE(review): the vertex normals are passed through without applying
    # ``rotation_matrix`` (unchanged from the previous implementation) --
    # confirm whether ``pcl.render_mesh`` expects them in the camera frame.
    mesh = trimesh.Trimesh(vertices,
                           np.asarray(o3d_mesh.triangles),
                           vertex_normals=np.asarray(o3d_mesh.vertex_normals))
    return pcl.render_mesh(mesh, camera_mag=1)


@client.command(name="plot_dataset")
@click.argument("--output-directory", type=str, required=True)
@click.argument("--sdf-dataset-directory", type=str, required=True)
@click.argument("--surface-dataset-directory", type=str, required=True)
@click.argument("--test-split-path", type=str, required=True)
@click.argument("--grid-size", type=int, default=256)
@click.argument("--seed", type=int, default=0)
@click.hyperparameter_class(ProposedMethodArgs)
@click.hyperparameter_class(IgrArgs)
@click.hyperparameter_class(DeepSdfArgs)
@click.hyperparameter_class(PcnArgs)
@click.hyperparameter_class(OccNetArgs)
def plot_dataset(args, proposed_method_args: ProposedMethodArgs,
                 igr_args: IgrArgs, deepsdf_args: DeepSdfArgs,
                 pcn_args: PcnArgs, occnet_args: OccNetArgs):
    """Save one method-comparison figure per object of the test split.

    Every figure has one row per entry of ``num_point_samples_list`` and
    eight columns. Columns 1-6 are handed to the PCN, OccNet, DeepSDF, IGR,
    proposed-method (without latent optimization) and proposed-method (with
    latent optimization) plot functions; column 7 shows a ground-truth mesh
    built from the surface points. Column 0 is not drawn by this function.
    Each figure is written as ``<category_id>_<model_id>.png`` and ``.pdf``
    into the output directory.

    Args:
        args: Parsed command-line options (output/dataset directories, test
            split path, grid size and seed).
        proposed_method_args: Hyperparameters of the proposed method.
        igr_args: Hyperparameters of the IGR baseline.
        deepsdf_args: Hyperparameters of the DeepSDF baseline.
        pcn_args: Hyperparameters of the PCN baseline.
        occnet_args: Hyperparameters of the OccNet baseline.

    Raises:
        FileNotFoundError: If the test split file does not exist.
    """
    # Honour ``--seed``: it was previously parsed but never applied, so the
    # object order (and any random sampling downstream) changed on every run.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    device = torch.device("cuda", 0)
    sdf_dataset_directory = Path(args.sdf_dataset_directory)
    surface_dataset_directory = Path(args.surface_dataset_directory)
    output_directory = Path(args.output_directory)
    mkdir(output_directory)

    print(proposed_method_args)
    print(igr_args)
    print(deepsdf_args)

    category_model_pair_list = _load_shuffled_category_model_pairs(
        Path(args.test_split_path))

    # Both datasets share the layout <root>/<category>/depth/<model>.npz.
    sdf_path_list = []
    pc_path_list = []
    pc_sdf_path_list = []
    for category_id, model_id in category_model_pair_list:
        sdf_path = sdf_dataset_directory / category_id / "depth" / f"{model_id}.npz"
        surface_points_path = surface_dataset_directory / category_id / "depth" / f"{model_id}.npz"
        sdf_path_list.append(sdf_path)
        pc_path_list.append(surface_points_path)
        pc_sdf_path_list.append((sdf_path, surface_points_path))
    print(len(pc_path_list))

    surface_dataset = UniformPointCloudDataset(pc_path_list)
    sdf_dataset = DeepSdfDataset(sdf_path_list)
    pcn_dataset = PcnDataset(pc_path_list)
    occnet_dataset = OccNetDataset(pc_sdf_path_list, memory_caching=False)

    num_point_samples_list = [50, 100, 300, 1000]
    # The proposed method is set up twice: zero latent-optimization
    # iterations ("w/o opt") and the configured number ("w/ opt").
    plot_proposed_method = meta_learning_sdf_comparison.proposed_method.setup_plot_function(
        args=proposed_method_args,
        device=device,
        num_point_samples_list=num_point_samples_list,
        latent_optimization_initial_lr=(
            proposed_method_args.latent_optimization_initial_lr),
        latent_optimization_iterations=0,
        grid_size=args.grid_size)
    plot_proposed_method.label = "Proposed Method \n w/o opt"
    plot_proposed_method_with_opt = meta_learning_sdf_comparison.proposed_method.setup_plot_function(
        args=proposed_method_args,
        device=device,
        num_point_samples_list=num_point_samples_list,
        latent_optimization_initial_lr=(
            proposed_method_args.latent_optimization_initial_lr),
        latent_optimization_iterations=(
            proposed_method_args.latent_optimization_iterations),
        grid_size=args.grid_size)
    plot_proposed_method_with_opt.label = "Proposed Method \n w/ opt"
    plot_igr = meta_learning_sdf_comparison.igr.setup_plot_function(
        args=igr_args,
        device=device,
        num_point_samples_list=num_point_samples_list,
        latent_optimization_initial_lr=igr_args.latent_optimization_initial_lr,
        latent_optimization_iterations=igr_args.latent_optimization_iterations,
        grid_size=args.grid_size)
    plot_deep_sdf = meta_learning_sdf_comparison.deep_sdf.setup_plot_function(
        args=deepsdf_args,
        device=device,
        num_sdf_samples_list=num_point_samples_list,
        latent_optimization_initial_lr=(
            deepsdf_args.latent_optimization_initial_lr),
        latent_optimization_iterations=(
            deepsdf_args.latent_optimization_iterations),
        grid_size=args.grid_size)
    plot_pcn = meta_learning_sdf_comparison.pcn.setup_plot_function(
        args=pcn_args,
        device=device,
        num_point_samples_list=num_point_samples_list)
    plot_occnet = meta_learning_sdf_comparison.occnet.setup_plot_function(
        args=occnet_args,
        device=device,
        num_point_samples_list=num_point_samples_list,
        grid_size=args.grid_size)

    rotation_matrix, translation_vector = _build_camera_extrinsics(
        theta=math.pi / 3, phi=-math.pi / 4, radius=1)

    # Figure geometry does not depend on the object, so compute it once.
    num_rows = len(num_point_samples_list)
    num_columns = 8
    ground_truth_column = 7
    dpi = 100
    figsize_inch = np.array([1600, 200 * num_rows]) / dpi

    for pc_data, sdf_data, pcn_data, occnet_data in zip(
            surface_dataset, sdf_dataset, pcn_dataset, occnet_dataset):
        print("processing", pc_data.path)

        # squeeze=False keeps ``axes_rows`` two-dimensional even for one row.
        fig, axes_rows = plt.subplots(num_rows,
                                      num_columns,
                                      figsize=figsize_inch,
                                      squeeze=False)
        for ax in axes_rows.flat:
            ax.axis("off")

        # The same ground-truth rendering is shown in every row.
        gt_image = _render_ground_truth_image(pc_data, rotation_matrix,
                                              translation_vector)
        for row, axes in enumerate(axes_rows):
            gt_ax = axes[ground_truth_column]
            gt_ax.imshow(gt_image)
            gt_ax.set_xticks([])
            gt_ax.set_yticks([])
            if row == 0:
                gt_ax.set_title("Ground truth", fontsize=10)

        print("plot PCN", flush=True)
        plot_pcn(pc_data=pcn_data, axes=axes_rows, column=1)

        print("plot OccNet", flush=True)
        plot_occnet(pc_data=occnet_data, axes=axes_rows, column=2)

        print("plot DeepSdf", flush=True)
        plot_deep_sdf(sdf_data=sdf_data, axes=axes_rows, column=3)

        print("plot IGR", flush=True)
        plot_igr(pc_data=pc_data, axes=axes_rows, column=4)

        print("plot proposed method", flush=True)
        plot_proposed_method(pc_data=pc_data, axes=axes_rows, column=5)

        print("plot proposed method w/ opt", flush=True)
        plot_proposed_method_with_opt(pc_data=pc_data,
                                      axes=axes_rows,
                                      column=6)

        # Path layout is <root>/<category>/depth/<model>.npz.
        parts = pc_data.path.parts
        category_id = parts[-3]
        model_id = parts[-1].replace(".npz", "")

        # Operate on ``fig`` explicitly instead of pyplot's current figure.
        fig.tight_layout()
        for extension in ("png", "pdf"):
            figure_path = output_directory / f"{category_id}_{model_id}.{extension}"
            fig.savefig(figure_path,
                        dpi=300,
                        bbox_inches="tight",
                        pad_inches=0.05)
        plt.close(fig)
        print(figure_path, flush=True)


# Invoke the command group only when this file is executed as a script.
if __name__ == "__main__":
    client()
