import json
import math
from collections import defaultdict
from decimal import ROUND_HALF_UP, Decimal
from pathlib import Path

import matplotlib.pylab as plt
import numpy as np
import pandas as pd
import seaborn as sns
from tabulate import tabulate


def _summarize(result_directory: str):
    result_directory = Path(result_directory)
    metrics_directory = result_directory / "metrics"
    args_path = result_directory / "args.json"
    args = None
    if args_path.is_file():
        with open(args_path) as f:
            args = json.load(f)
    table = defaultdict(list)
    metrics_path_list = list(metrics_directory.glob("*.json"))
    for metrics_path in metrics_path_list:
        try:
            with open(metrics_path) as f:
                metrics = json.load(f)
        except json.decoder.JSONDecodeError:
            print("Error:", metrics_path)
            continue
        for key, value in metrics.items():
            table[key].append(value)

    return table, args, len(metrics_path_list)


# Every method was evaluated at the same context sizes. Within a method, the
# result directories differ only in that size, so each list is generated
# from a single path template.
_CONTEXT_SIZES = (50, 100, 300, 1000)

_PROPOSED_ROOT = ("/mnt/hdd/mnt/hdd/neurips/meta_learning_sdf/surface_cd/"
                  "uniform_sparse_sampling/baseline/0531/asymmetric/cd_30000/"
                  "icl_nuim/epoch_5000")

# Proposed method evaluated without latent optimization.
proposed_lo_0_directories = [
    f"{_PROPOSED_ROOT}/latent_optim_0/{size}_samples/lr_0.005/1d7bd0ff765b/"
    for size in _CONTEXT_SIZES
]

# Proposed method evaluated with 800 latent-optimization steps (per path).
proposed_lo_800_directories = [
    f"{_PROPOSED_ROOT}/latent_optim_800/{size}_samples/lr_0.005/1d7bd0ff765b/"
    for size in _CONTEXT_SIZES
]

gauss_densification_directories = [
    "/mnt/hdd/mnt/hdd/neurips/pcn/icl_nuim/random_baseline/asymmetric/"
    f"variance_0.0005/{size}_inputs" for size in _CONTEXT_SIZES
]

pcn_directories = [
    "/mnt/hdd/mnt/hdd/neurips/pcn/uniform_sparse_sampling/0531/asymmetric/"
    f"icl_nuim/epoch_5000/{size}_inputs/b860c95e5203"
    for size in _CONTEXT_SIZES
]

occnet_directories = [
    "/mnt/hdd/mnt/hdd/neurips/occupancy_networks/uniform_sparse_sampling/0531/"
    f"asymmetric/icl_nuim/epoch_3000/{size}_inputs/a801cb8e67ab/"
    for size in _CONTEXT_SIZES
]

# NOTE(review): these paths say "idl_nuim" where every other method uses
# "icl_nuim" -- possibly a typo; confirm against the directory on disk.
igr_directories = [
    "/mnt/hdd/mnt/hdd/neurips/implicit-geometry-regularization/0531/"
    "uniform_sparse_sampling/asymmetric/cd_30000/idl_nuim/epoch_5000/"
    f"latent_optim_800/{size}_context/lr_0.05/cbbaf94dc7d3/"
    for size in _CONTEXT_SIZES
]

deepsdf_directories = [
    "/mnt/hdd/mnt/hdd/neurips/deep_sdf/0531/uniform_sparse_sampling/asymmetric/"
    f"cd_30000/icl_nuim/epoch_3000/latent_optim_800/{size}_context/lr_0.05/"
    "6f26a80fb1fb/" for size in _CONTEXT_SIZES
]

# Methods in the order they are compared. Each method's Chamfer distance is
# later divided by that of `base_method_name` for the same model and
# context size.
method_name_list = [
    "gauss_densification",
    "pcn",
    "occnet",
    "deepsdf",
    "igr",
    "proposed_lo_0",
    "proposed_lo_800",
]
base_method_name = "proposed_lo_0"

# method -> (result directories, metrics key that records how many context
# points the run used). Insertion order is the order results are loaded.
_method_sources = {
    "proposed_lo_0": (proposed_lo_0_directories,
                      "latent_optimization_num_samples"),
    "proposed_lo_800": (proposed_lo_800_directories,
                        "latent_optimization_num_samples"),
    "pcn": (pcn_directories, "num_input_points"),
    "gauss_densification": (gauss_densification_directories,
                            "num_input_points"),
    "occnet": (occnet_directories, "num_input_points"),
    "igr": (igr_directories, "latent_optimization_num_samples"),
    "deepsdf": (deepsdf_directories, "latent_optimization_num_samples"),
}


def _load_chamfer_distances(directories, num_context_key):
    """Collect per-model Chamfer distances for one method.

    Args:
        directories: Result directories, one per context size.
        num_context_key: Metrics key whose first value is the context size
            of that directory's run.

    Returns:
        ``{num_context: {model_id: chamfer_distance}}``. A directory whose
        metrics do not provide ``num_context_key`` (for example, no readable
        metrics files) is reported and skipped.
    """
    per_context = {}
    for result_directory in directories:
        result, _, num_data = _summarize(result_directory)
        print(num_data)
        num_context_values = result.get(num_context_key)
        if not num_context_values:
            print("No metrics found:", result_directory)
            continue
        per_context[num_context_values[0]] = dict(
            zip(result["model_id"], result["chamfer_distance"]))
    return per_context


# method -> num_context -> model_id -> chamfer distance. Values may be None;
# such entries are skipped when building the plot data.
map_method_chamfer_distance = {
    method_name: {} for method_name in method_name_list
}
for method_name, (directories, num_context_key) in _method_sources.items():
    map_method_chamfer_distance[method_name] = _load_chamfer_distances(
        directories, num_context_key)

# Only models the baseline evaluated can be normalized, so the baseline
# defines which models are plotted.
model_id_set = {
    model_id
    for per_model in map_method_chamfer_distance[base_method_name].values()
    for model_id in per_model
}

map_method_name_label = {
    "proposed_lo_0": "Proposed Method w/o opt",
    "proposed_lo_800": "Proposed Method w/ opt",
    "gauss_densification": "Gauss Densification",
    "pcn": "PCN",
    "occnet": "OccNet",
    "igr": "IGR",
    "deepsdf": "DeepSDF",
}


def _is_plottable(chamfer_distance):
    """Return True if *chamfer_distance* can be log-transformed and divided by."""
    # None marks a missing result. math.log raises on 0 and negative values,
    # and a zero baseline would also cause a ZeroDivisionError.
    return chamfer_distance is not None and chamfer_distance > 0


num_context_list = [50, 100, 300, 1000]
data_list = []
# Sorted so that row order (and with it seaborn's default hue order) does not
# depend on set iteration order, which can vary between runs for str ids.
for model_id in sorted(model_id_set, key=str):
    for num_context in num_context_list:
        # .get(): a model may be missing for some context sizes, or a size
        # may be missing entirely; skip instead of raising KeyError.
        base_chamfer_distance = map_method_chamfer_distance[
            base_method_name].get(num_context, {}).get(model_id)
        if not _is_plottable(base_chamfer_distance):
            print(model_id)
            continue
        for method_name in method_name_list:
            chamfer_distance = map_method_chamfer_distance[method_name].get(
                num_context, {}).get(model_id)
            if not _is_plottable(chamfer_distance):
                print(method_name, model_id)
                continue
            relative_cd = chamfer_distance / base_chamfer_distance
            data_list.append({
                "Chamfer Distance":
                chamfer_distance * 1000,
                "Log Chamfer Distance +10":
                math.log(chamfer_distance) + 10,
                "Log Normalized Chamfer Distance +1":
                math.log(relative_cd) + 1,
                "Method":
                map_method_name_label[method_name],
                "model_id":
                model_id,
                "Number of Context":
                num_context,
            })

df = pd.DataFrame(data_list)


def change_width(ax, new_value):
    """Resize every patch on *ax* to width *new_value*, keeping its centre fixed.

    Args:
        ax: Axes whose ``patches`` (e.g. seaborn bars) are resized in place.
        new_value: New width for every patch, in data units.
    """
    for bar in ax.patches:
        old_width = bar.get_width()
        bar.set_width(new_value)
        # Move the left edge by half the width change so the bar stays centred.
        bar.set_x(bar.get_x() + (old_width - new_value) * .5)


# Figure size is specified in pixels and converted to inches for matplotlib.
figsize_px = np.array([2000, 300])
dpi = 100
figsize_inch = figsize_px / dpi

# Fix bar/legend order to method_name_list instead of relying on the order
# in which methods first appear in df.
hue_order = [map_method_name_label[name] for name in method_name_list]

plt.figure(figsize=figsize_inch)
for subplot_index, y_column in enumerate(
        ["Log Chamfer Distance +10", "Log Normalized Chamfer Distance +1"],
        start=1):
    plt.subplot(1, 2, subplot_index)
    # `palette` is a seaborn plotting option. It was previously passed to
    # plt.savefig, which has no such parameter, so it was never applied.
    # NOTE(review): `ci`/`errwidth` are deprecated in newer seaborn in favour
    # of `errorbar="sd"`/`err_kws`; kept here for older installs.
    ax = sns.barplot(x="Number of Context",
                     y=y_column,
                     hue="Method",
                     hue_order=hue_order,
                     palette="colorblind",
                     errwidth=1,
                     ci="sd",
                     data=df)
    # Narrow the bars, keeping each one centred on its original position.
    change_width(ax, .04)

plt.suptitle("ICL-NUIM")
plt.savefig("icl_nuim.pdf",
            dpi=dpi,
            bbox_inches="tight",
            pad_inches=0.05)
