import json
from collections import defaultdict
from decimal import ROUND_HALF_UP, Decimal
from pathlib import Path
import math
import matplotlib.pylab as plt
import numpy as np
import pandas as pd
import seaborn as sns
from tabulate import tabulate


def _summarize(result_directory: str):
    """Collect per-shape metrics stored as JSON files under ``<result_directory>/metrics``.

    Each ``*.json`` file in the ``metrics`` sub-directory is read as one
    record (a flat dict). Records are transposed into a column table so that
    ``table[key][i]`` is the value of ``key`` in the i-th successfully parsed
    file. A record that lacks a key contributes ``None`` to that column. This
    keeps every column the same length, so zipping two columns (e.g.
    ``model_id`` and ``chamfer_distance``) always pairs values that came from
    the same file.

    Files that are not valid JSON are reported on stdout and skipped.

    Args:
        result_directory: Directory containing a ``metrics`` sub-directory
            and, optionally, an ``args.json`` file.

    Returns:
        A tuple ``(table, args, num_files)``:

        - ``table`` is a ``defaultdict(list)`` mapping each metric name to its
          column of values.
        - ``args`` is the parsed ``args.json``, or ``None`` if that file is
          absent.
        - ``num_files`` is the number of ``*.json`` files found, including
          ones that failed to parse.

    Raises:
        FileNotFoundError: If no ``*.json`` file exists under ``metrics``.
    """
    result_directory = Path(result_directory)
    metrics_directory = result_directory / "metrics"
    args_path = result_directory / "args.json"

    args = None
    if args_path.is_file():
        with open(args_path, encoding="utf-8") as f:
            args = json.load(f)

    # Sorted so the row order does not depend on the filesystem listing order.
    metrics_path_list = sorted(metrics_directory.glob("*.json"))
    if not metrics_path_list:
        # Explicit raise rather than ``assert``, which is stripped under -O.
        raise FileNotFoundError(
            f"No metrics JSON files found in {metrics_directory}")

    records = []
    for metrics_path in metrics_path_list:
        try:
            with open(metrics_path, encoding="utf-8") as f:
                records.append(json.load(f))
        except json.decoder.JSONDecodeError:
            print("Error:", metrics_path)

    # Union of all keys in first-seen order; every column gets exactly one
    # entry per record (None where the record lacks the key).
    keys = list(dict.fromkeys(key for record in records for key in record))
    table = defaultdict(list)
    for key in keys:
        table[key] = [record.get(key) for record in records]

    return table, args, len(metrics_path_list)


# Result directories for every compared method. Each list holds one run per
# context size, in the order 50, 100, 300, 1000; only that number differs
# between the paths of one method.

# Proposed method without test-time latent optimisation (latent_optim_0).
proposed_lo_0_directories = [
    "/mnt/hdd/mnt/hdd/neurips/meta_learning_sdf/surface_cd/"
    "uniform_sparse_sampling/baseline/0510/symmetric/cd_30000/"
    "40_categories_novel/epoch_5000/latent_optim_0/"
    f"{num_samples}_samples/lr_0.005/ec27d70e7cf0/"
    for num_samples in (50, 100, 300, 1000)
]

# Proposed method with 800 latent-optimisation steps (latent_optim_800).
proposed_lo_800_directories = [
    "/mnt/hdd/mnt/hdd/neurips/meta_learning_sdf/surface_cd/"
    "uniform_sparse_sampling/baseline/0510/symmetric/cd_30000/"
    "40_categories_novel/epoch_5000/latent_optim_800/"
    f"{num_samples}_samples/lr_0.005/ec27d70e7cf0/"
    for num_samples in (50, 100, 300, 1000)
]

gauss_densification_directories = [
    "/mnt/hdd/mnt/hdd/neurips/pcn/uniform_sparse_sampling/"
    "random_baseline/symmetric/variance_5e-05/"
    f"40_categories_novel/{num_inputs}_inputs"
    for num_inputs in (50, 100, 300, 1000)
]

pcn_directories = [
    "/mnt/hdd/mnt/hdd/neurips/pcn/uniform_sparse_sampling/0525/symmetric/"
    "40_categories_novel/epoch_5000/"
    f"{num_inputs}_inputs/81d2edd77142"
    for num_inputs in (50, 100, 300, 1000)
]

occnet_directories = [
    "/mnt/hdd/mnt/hdd/neurips/occupancy_networks/uniform_sparse_sampling/"
    "0517/symmetric/40_categories_novel/epoch_5000/"
    f"{num_inputs}_inputs/b6ae866d88ba/"
    for num_inputs in (50, 100, 300, 1000)
]

igr_directories = [
    "/mnt/hdd/mnt/hdd/neurips/implicit-geometry-regularization/0510/"
    "uniform_sparse_sampling/cd_30000/40_categories_novel/epoch_5000/"
    f"latent_optim_800/{num_context}_context/lr_0.05/eb771fe8d137/"
    for num_context in (50, 100, 300, 1000)
]

deepsdf_directories = [
    "/mnt/hdd/mnt/hdd/neurips/deep_sdf/0510/learning_shape_space/"
    "uniform_sparse_sampling/cd_30000/40_categories_novel/epoch_5000/"
    f"latent_optim_800/{num_context}_context/lr_0.05/fd4c9568e7cc/"
    for num_context in (50, 100, 300, 1000)
]

# Methods to compare, in the order they are iterated when building rows.
method_name_list = [
    "gauss_densification",
    "pcn",
    "occnet",
    "deepsdf",
    "igr",
    "proposed_lo_0",
    "proposed_lo_800",
]
# method name -> {num_context -> {model_id -> chamfer distance}}.
# Each method gets its own (initially empty) inner dict.
map_method_chamfer_distance = {
    method_name: {} for method_name in method_name_list
}
# Model ids to compare; collected from the proposed_lo_0 results below.
model_id_set = set()

# Load every method's per-shape Chamfer distances into
# map_method_chamfer_distance[method][num_context][model_id].
# Each entry is (method name, result directories, metrics key holding the
# number of context points). The key name differs between methods.
method_result_sources = [
    ("gauss_densification", gauss_densification_directories,
     "num_input_points"),
    ("pcn", pcn_directories, "num_input_points"),
    ("occnet", occnet_directories, "num_input_points"),
    ("deepsdf", deepsdf_directories, "latent_optimization_num_samples"),
    ("igr", igr_directories, "latent_optimization_num_samples"),
    ("proposed_lo_0", proposed_lo_0_directories,
     "latent_optimization_num_samples"),
    ("proposed_lo_800", proposed_lo_800_directories,
     "latent_optimization_num_samples"),
]

for method_name, result_directories, num_context_key in method_result_sources:
    for result_directory in result_directories:
        result, _, num_data = _summarize(result_directory)
        print(num_data)
        # Every metrics file in one directory must share a single context
        # size; anything else means the directory is mislabeled or mixed.
        num_context_values = {
            value for value in result[num_context_key] if value is not None
        }
        if len(num_context_values) != 1:
            raise ValueError(
                f"Expected exactly one {num_context_key!r} value in "
                f"{result_directory}, got {num_context_values!r}")
        num_context = num_context_values.pop()
        map_method_chamfer_distance[method_name][num_context] = dict(
            zip(result["model_id"], result["chamfer_distance"]))
        # Only the base method (proposed_lo_0) defines which shapes are
        # compared when the comparison table is built.
        if method_name == "proposed_lo_0":
            model_id_set.update(
                map_method_chamfer_distance[method_name][num_context])

# Legend label for each method key.
map_method_name_label = {
    "proposed_lo_0": "Proposed Method w/o opt",
    "proposed_lo_800": "Proposed Method w/ opt",
    "gauss_densification": "Gauss Densification",
    "pcn": "PCN",
    "occnet": "OccNet",
    "igr": "IGR",
    "deepsdf": "DeepSDF",
}

# Build one row per (shape, context size, method). Each method's Chamfer
# distance is also divided by the base method's distance for the same shape
# and context size; a value < 1 means a lower Chamfer distance than the base.
num_context_list = [50, 100, 300, 1000]
data_list = []
base_method_name = "proposed_lo_0"
base_results = map_method_chamfer_distance[base_method_name]
# Sorted so the row order is reproducible (set order varies between runs);
# key=str tolerates non-string ids.
for model_id in sorted(model_id_set, key=str):
    for num_context in num_context_list:
        base_chamfer_distance = base_results.get(num_context,
                                                 {}).get(model_id)
        # A missing (None) or zero base distance cannot be used to normalize.
        if not base_chamfer_distance:
            print(model_id)
            continue
        for method_name in method_name_list:
            chamfer_distance = map_method_chamfer_distance[method_name].get(
                num_context, {}).get(model_id)
            # Missing result or explicit null in the metrics file.
            if chamfer_distance is None:
                print(method_name, model_id)
                continue
            relative_cd = chamfer_distance / base_chamfer_distance
            data_list.append({
                "Chamfer Distance x1000": chamfer_distance * 1000,
                "Normalized Chamfer Distance": relative_cd,
                "Method": map_method_name_label[method_name],
                "model_id": model_id,
                "Number of Context": num_context,
            })

df = pd.DataFrame(data_list)

# grouped = df.groupby(["Number of Context", "Method"])

# df.plot(x="Number of Context",
#         y="Log Chamfer Distance +10",
#         kind='bar',
#         alpha=0.75,
#         rot=0)
# plt.show()
# mean = grouped.mean()
# std = grouped.std()
# print(mean)
# print(std)
# print(mean.columns)
# print(mean.axes)


def change_width(ax, new_value):
    """Set every patch on *ax* to width *new_value*, keeping its center fixed.

    Args:
        ax: Axes-like object whose ``patches`` are resized in place.
        new_value: Target width, in the same units as ``patch.get_width()``.
    """
    for bar in ax.patches:
        old_width = bar.get_width()
        bar.set_width(new_value)
        # Move the left edge by half the width change so the center stays put.
        bar.set_x(bar.get_x() + (old_width - new_value) / 2)


# Two side-by-side bar charts per context size and method: absolute Chamfer
# distance (x1000) and distance normalized by the base method. Both are saved
# to a single PDF.
figsize_px = np.array([2000, 300])
dpi = 100
figsize_inch = figsize_px / dpi
plt.figure(figsize=figsize_inch)
for subplot_index, y_column in enumerate(
        ["Chamfer Distance x1000", "Normalized Chamfer Distance"], start=1):
    plt.subplot(1, 2, subplot_index)
    # NOTE(review): ``ci``/``errwidth`` are the older seaborn spellings; newer
    # seaborn deprecates them in favour of ``errorbar="sd"``/``err_kws``.
    ax = sns.barplot(x="Number of Context",
                     y=y_column,
                     hue="Method",
                     errwidth=1,
                     ci="sd",
                     palette="colorblind",
                     data=df)
    # Narrow the bars so that every method fits inside each context group.
    change_width(ax, .04)

plt.suptitle("ShapeNet novel categories")
# ``palette`` is a plotting option (passed to barplot above). ``savefig`` has
# no such parameter, so it must not be passed here.
plt.savefig("shapenet_40_categories_novel.pdf",
            dpi=dpi,
            bbox_inches="tight",
            pad_inches=0.05)
