import json
import math
import random
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, List

import colorful
import matplotlib.pyplot as plt
import numpy as np
import torch
import trimesh

import meta_learning_sdf
import meta_learning_sdf_comparison
from meta_learning_sdf import click, mkdir
from meta_learning_sdf import point_cloud as pcl
from meta_learning_sdf.datasets.functions import (read_mesh_data,
                                                  read_sdf_data,
                                                  read_uniform_point_cloud_data
                                                  )
from meta_learning_sdf_comparison.args import (DeepSdfArgs, IgrArgs,
                                               OccNetArgs, PcnArgs,
                                               ProposedMethodArgs)
from meta_learning_sdf_comparison.datasets import (OccNetDataset, PcnDataset,
                                                   CombinedDataset)
from meta_learning_sdf_comparison.geometry import get_extrinsics


# Root command group; sub-commands attach themselves with @client.command.
@click.group()
def client():
    # Intentionally empty: the group only dispatches to its sub-commands.
    ...


def _seed_everything(seed: int):
    """Seed Python's, NumPy's and PyTorch's global RNGs with ``seed``."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


def _read_category_model_pairs(test_split_path: Path) -> List[tuple]:
    """Return the ``(category_id, model_id)`` pairs listed in a JSON split file.

    The top-level JSON object is iterated as ``{category_id: [model_id, ...]}``;
    model ids are sorted within each category.

    Args:
        test_split_path: Path to the JSON split file.

    Raises:
        FileNotFoundError: If ``test_split_path`` is not an existing file.
    """
    # An explicit check instead of ``assert`` so it still runs under ``python -O``.
    if not test_split_path.is_file():
        raise FileNotFoundError(f"test split not found: {test_split_path}")
    with open(test_split_path, encoding="utf-8") as f:
        split = json.load(f)
    return [(category_id, model_id)
            for category_id in split
            for model_id in sorted(split[category_id])]


def _render_ground_truth(pc_data, mesh_data):
    """Render the ground-truth mesh as an image.

    The mesh vertices get the point cloud's ``offset``/``scale``
    normalization, then the rotation/translation from ``get_extrinsics()``.
    The result is passed to ``pcl.render_mesh``.
    """
    rotation_matrix, translation_vector = get_extrinsics()
    gt_vertices = (mesh_data.vertices + pc_data.offset) * pc_data.scale
    gt_vertices = (rotation_matrix @ gt_vertices.T).T + translation_vector
    gt_mesh = trimesh.Trimesh(vertices=gt_vertices,
                              faces=mesh_data.vertex_indices)
    return pcl.render_mesh(gt_mesh, camera_mag=1)


@client.command(name="plot_dataset")
@click.argument("--output-directory", type=str, required=True)
@click.argument("--sdf-dataset-directory", type=str, required=True)
@click.argument("--pc-dataset-directory", type=str, required=True)
@click.argument("--mesh-dataset-directory", type=str, required=True)
@click.argument("--test-split-path", type=str, required=True)
@click.argument("--grid-size", type=int, default=256)
@click.argument("--seed", type=int, default=0)
@click.hyperparameter_class(ProposedMethodArgs)
@click.hyperparameter_class(IgrArgs)
@click.hyperparameter_class(DeepSdfArgs)
@click.hyperparameter_class(PcnArgs)
@click.hyperparameter_class(OccNetArgs)
def plot_dataset(args, proposed_method_args: ProposedMethodArgs,
                 igr_args: IgrArgs, deepsdf_args: DeepSdfArgs,
                 pcn_args: PcnArgs, occnet_args: OccNetArgs):
    """Plot every method's reconstruction of each test-split shape side by side.

    For each ``(category_id, model_id)`` in the test split, the figure is
    written to ``<output-directory>/<category_id>_<model_id>.png`` and
    ``.pdf``.

    Figure layout:
    - One row per entry of ``num_point_samples_list``.
    - Columns 1-6: PCN, OccNet, DeepSDF, IGR, and the proposed method without
      and with latent optimization.
    - Column 7: the rendered ground-truth mesh.
    - Column 0: nothing in this function draws into it.

    Args:
        args: Parsed command-line options (dataset directories, split path,
            ``grid_size`` and ``seed``).
        proposed_method_args, igr_args, deepsdf_args, pcn_args, occnet_args:
            Hyperparameters forwarded to each method's ``setup_plot_function``.

    Raises:
        FileNotFoundError: If ``--test-split-path`` is not an existing file.
    """
    # Make the shuffle below (and any RNG use in the plot functions) reproducible.
    _seed_everything(args.seed)

    device = torch.device("cuda", 0)
    sdf_dataset_directory = Path(args.sdf_dataset_directory)
    pc_dataset_directory = Path(args.pc_dataset_directory)
    # Matches the --mesh-dataset-directory option declared above.
    mesh_dataset_directory = Path(args.mesh_dataset_directory)
    output_directory = Path(args.output_directory)
    mkdir(output_directory)

    print(proposed_method_args)
    print(igr_args)
    print(deepsdf_args)
    print(pcn_args)
    print(occnet_args)

    category_model_pair_list = _read_category_model_pairs(
        Path(args.test_split_path))
    random.shuffle(category_model_pair_list)

    pc_path_list = []
    sdf_mesh_path_list = []
    pc_mesh_path_list = []
    pc_sdf_path_list = []
    for category_id, model_id in category_model_pair_list:
        model_directory = Path(category_id) / model_id
        sdf_path = sdf_dataset_directory / model_directory / "sdf.npz"
        surface_points_path = (pc_dataset_directory / model_directory /
                               "point_cloud.npz")
        obj_path = (mesh_dataset_directory / model_directory / "models" /
                    "model_normalized.obj")
        sdf_mesh_path_list.append((sdf_path, obj_path))
        pc_mesh_path_list.append((surface_points_path, obj_path))
        pc_path_list.append(surface_points_path)
        pc_sdf_path_list.append((sdf_path, surface_points_path))
    print(len(pc_mesh_path_list))

    pc_mesh_dataset = CombinedDataset(
        pc_mesh_path_list, [read_uniform_point_cloud_data, read_mesh_data])
    pcn_dataset = PcnDataset(pc_path_list)
    # SDF samples must come from the sdf.npz paths, not the point-cloud paths.
    sdf_mesh_dataset = CombinedDataset(sdf_mesh_path_list,
                                       [read_sdf_data, read_mesh_data])
    occnet_dataset = OccNetDataset(pc_sdf_path_list, memory_caching=False)

    num_point_samples_list = [50, 100, 300, 1000]
    plot_proposed_method = meta_learning_sdf_comparison.proposed_method.setup_plot_function(
        args=proposed_method_args,
        device=device,
        num_point_samples_list=num_point_samples_list,
        latent_optimization_initial_lr=(
            proposed_method_args.latent_optimization_initial_lr),
        latent_optimization_iterations=0,
        grid_size=args.grid_size)
    plot_proposed_method.label = "Proposed Method \n w/o opt"
    plot_proposed_method_with_opt = meta_learning_sdf_comparison.proposed_method.setup_plot_function(
        args=proposed_method_args,
        device=device,
        num_point_samples_list=num_point_samples_list,
        latent_optimization_initial_lr=(
            proposed_method_args.latent_optimization_initial_lr),
        latent_optimization_iterations=(
            proposed_method_args.latent_optimization_iterations),
        grid_size=args.grid_size)
    plot_proposed_method_with_opt.label = "Proposed Method \n w/ opt"
    plot_igr = meta_learning_sdf_comparison.igr.setup_plot_function(
        args=igr_args,
        device=device,
        num_point_samples_list=num_point_samples_list,
        latent_optimization_initial_lr=igr_args.latent_optimization_initial_lr,
        latent_optimization_iterations=igr_args.latent_optimization_iterations,
        grid_size=args.grid_size)
    plot_deep_sdf = meta_learning_sdf_comparison.deep_sdf.setup_plot_function(
        args=deepsdf_args,
        device=device,
        num_sdf_samples_list=num_point_samples_list,
        latent_optimization_initial_lr=(
            deepsdf_args.latent_optimization_initial_lr),
        latent_optimization_iterations=(
            deepsdf_args.latent_optimization_iterations),
        grid_size=args.grid_size)
    plot_pcn = meta_learning_sdf_comparison.pcn.setup_plot_function(
        args=pcn_args,
        device=device,
        num_point_samples_list=num_point_samples_list)
    plot_occnet = meta_learning_sdf_comparison.occnet.setup_plot_function(
        args=occnet_args,
        device=device,
        num_point_samples_list=num_point_samples_list,
        grid_size=args.grid_size)

    num_rows = len(num_point_samples_list)
    num_columns = 8
    gt_column = num_columns - 1
    dpi = 100
    figsize_inch = np.array([1600, 200 * num_rows]) / dpi

    for pc_mesh_data_pair, sdf_mesh_data_pair, pcn_data, occnet_data in zip(
            pc_mesh_dataset, sdf_mesh_dataset, pcn_dataset, occnet_dataset):
        pc_data = pc_mesh_data_pair[0]  # output of read_uniform_point_cloud_data
        mesh_data = pc_mesh_data_pair[1]  # output of read_mesh_data
        sdf_data = sdf_mesh_data_pair[0]  # output of read_sdf_data

        # squeeze=False always yields a 2-D axes array, even for a single row.
        fig, axes_rows = plt.subplots(num_rows,
                                      num_columns,
                                      figsize=figsize_inch,
                                      squeeze=False)
        try:
            for ax in axes_rows.flat:
                ax.axis("off")

            print("plot ground truth mesh", flush=True)
            gt_image = _render_ground_truth(pc_data, mesh_data)
            for row in range(num_rows):
                gt_ax = axes_rows[row][gt_column]
                gt_ax.imshow(gt_image)
                gt_ax.set_xticks([])
                gt_ax.set_yticks([])
                if row == 0:
                    gt_ax.set_title("Ground truth", fontsize=10)

            print("plot PCN", flush=True)
            plot_pcn(pc_data=pcn_data, axes=axes_rows, column=1)

            print("plot OccNet", flush=True)
            plot_occnet(pc_data=occnet_data, axes=axes_rows, column=2)

            print("plot DeepSdf", flush=True)
            plot_deep_sdf(sdf_data=sdf_data, axes=axes_rows, column=3)

            print("plot IGR", flush=True)
            plot_igr(pc_data=pc_data, axes=axes_rows, column=4)

            print("plot proposed method", flush=True)
            plot_proposed_method(pc_data=pc_data, axes=axes_rows, column=5)

            print("plot proposed method w/ opt", flush=True)
            plot_proposed_method_with_opt(pc_data=pc_data,
                                          axes=axes_rows,
                                          column=6)

            # The point-cloud path is .../<category_id>/<model_id>/point_cloud.npz.
            category_id, model_id = Path(pc_data.path).parts[-3:-1]

            # Operate on ``fig`` explicitly rather than pyplot's "current" figure.
            fig.tight_layout()
            for suffix in (".png", ".pdf"):
                figure_path = output_directory / f"{category_id}_{model_id}{suffix}"
                fig.savefig(figure_path,
                            dpi=300,
                            bbox_inches="tight",
                            pad_inches=0.05)
        finally:
            # Always release the figure, even if a plot function raised.
            plt.close(fig)
        print(figure_path, flush=True)


if __name__ == "__main__":
    # Hand control to the command group, which dispatches on the sub-command name.
    client()
