import json
from collections import defaultdict
from decimal import ROUND_HALF_UP, Decimal
from pathlib import Path

import pandas as pd
from tabulate import tabulate

# Category id -> human-readable category name. The category id is the part of
# a model_id before the first "_" (see the row-building loop below), and the
# LaTeX table is emitted one section per entry, in this dict's order.
# NOTE(review): the ids look like ShapeNetCore synset ids — confirm.
_map_category_name = {
    "02691156": "Plane",
    "02747177": "Garbage Can",
    "02773838": "Bag",
    "02801938": "Basket",
    "02808440": "Bathtub",
    "02818832": "Bed",
    "02828884": "Bench",
    "02834778": "Bicycle",
    "02843684": "Birdhouse",
    "02871439": "Bookshelf",
    "02876657": "Bottle",
    "02880940": "Bowl",
    "02924116": "Bus",
    "02933112": "Cabinet",
    "02942699": "Camera",
    "02946921": "Can",
    "02954340": "Cap",
    "02958343": "Car",
    "02992529": "Cellular Phone",
    "03001627": "Chair",
    "03046257": "Clock",
    "03085013": "Keyboard",
    "03207941": "Dishwasher",
    "03211117": "Display",
    "03261776": "Earphone",
    "03325088": "Faucet",
    "03337140": "File Cabinet",
    "03467517": "Guitar",
    "03513137": "Helmet",
    "03593526": "Jar",
    "03624134": "Knife",
    "03636649": "Lamp",
    "03642806": "Laptop",
    "03691459": "Speaker",
    "03710193": "Mailbox",
    "03759954": "Microphone",
    "03761084": "Microwave",
    "03790512": "Motorcycle",
    "03797390": "Mug",
    "03928116": "Piano",
    "03938244": "Pillow",
    "03948459": "Pistol",
    "03991062": "Pot",
    "04004475": "Printer",
    "04074963": "Remote Control",
    "04090263": "Rifle",
    "04099429": "Rocket",
    "04225987": "Skateboard",
    "04256520": "Sofa",
    "04330267": "Stove",
    "04379243": "Table",
    "04401088": "Telephone",
    "04460130": "Tower",
    "04468005": "Train",
    "04530566": "Vessel",
    "04554684": "Washer",
    "04591713": "Wine Bottle"
}


def _summarize(result_directory: str):
    """Load every per-model metrics JSON file of one experiment directory.

    Args:
        result_directory: Experiment directory. It must contain a ``metrics``
            sub-directory holding ``*.json`` files (one JSON object each) and
            may contain an ``args.json`` file with the run's arguments.

    Returns:
        A tuple ``(table, args, num_files)``:

        * ``table`` maps each metric key to the list of its values, one
          entry per successfully parsed file, in sorted file-name order.
        * ``args`` is the parsed ``args.json``, or ``None`` if it is absent.
        * ``num_files`` is the number of ``*.json`` files found, including
          any that failed to parse and were skipped.

    Raises:
        FileNotFoundError: if the ``metrics`` directory holds no ``*.json``
            file (or does not exist).
    """
    result_directory = Path(result_directory)
    metrics_directory = result_directory / "metrics"
    args_path = result_directory / "args.json"
    args = None
    if args_path.is_file():
        with open(args_path, encoding="utf-8") as f:
            args = json.load(f)
    table = defaultdict(list)
    # Sorted so the value lists have a reproducible order; raw glob order
    # depends on the filesystem.
    metrics_path_list = sorted(metrics_directory.glob("*.json"))
    if not metrics_path_list:
        # Explicit raise instead of ``assert``, which is stripped under -O.
        raise FileNotFoundError(
            f"no metrics *.json files found in {metrics_directory}")
    for metrics_path in metrics_path_list:
        try:
            with open(metrics_path, encoding="utf-8") as f:
                metrics = json.load(f)
        except json.decoder.JSONDecodeError:
            # Best effort: report and skip unreadable files.
            print("Error:", metrics_path)
            continue
        # NOTE(review): callers pair lists by position (e.g. model_id with
        # chamfer_distance), which assumes every file has the same keys.
        for key, value in metrics.items():
            table[key].append(value)

    return table, args, len(metrics_path_list)


def _expand_context_sizes(template):
    """Format ``template`` once per context size, in the order 50, 100, 300, 1000."""
    return [template.format(size=size) for size in (50, 100, 300, 1000)]


# Evaluation result directories of each method, one per number of
# context/input points, always ordered 50, 100, 300, 1000.
_RESULTS_ROOT = "/mnt/hdd/mnt/hdd/neurips/"

proposed_lo_0_directories = _expand_context_sizes(
    _RESULTS_ROOT
    + "meta_learning_sdf/surface_cd/uniform_sparse_sampling/baseline/0510/symmetric/cd_30000/40_categories/epoch_5000/latent_optim_0/{size}_samples/lr_0.005/ec27d70e7cf0/"
)

proposed_lo_800_directories = _expand_context_sizes(
    _RESULTS_ROOT
    + "meta_learning_sdf/surface_cd/uniform_sparse_sampling/baseline/0510/symmetric/cd_30000/40_categories/epoch_5000/latent_optim_800/{size}_samples/lr_0.005/ec27d70e7cf0/"
)

gauss_densification_directories = _expand_context_sizes(
    _RESULTS_ROOT
    + "pcn/uniform_sparse_sampling/random_baseline/symmetric/variance_5e-05/40_categories/{size}_inputs"
)

pcn_directories = _expand_context_sizes(
    _RESULTS_ROOT
    + "pcn/uniform_sparse_sampling/0525/symmetric/40_categories/epoch_5000/{size}_inputs/81d2edd77142"
)

occnet_directories = _expand_context_sizes(
    _RESULTS_ROOT
    + "occupancy_networks/uniform_sparse_sampling/0517/symmetric/40_categories/epoch_5000/{size}_inputs/b6ae866d88ba/"
)

igr_directories = _expand_context_sizes(
    _RESULTS_ROOT
    + "implicit-geometry-regularization/0510/uniform_sparse_sampling/cd_30000/40_categories/epoch_5000/latent_optim_800/{size}_context/lr_0.05/eb771fe8d137/"
)

deepsdf_directories = _expand_context_sizes(
    _RESULTS_ROOT
    + "deep_sdf/0510/learning_shape_space/uniform_sparse_sampling/cd_30000/40_categories/epoch_5000/latent_optim_800/{size}_context/lr_0.05/fd4c9568e7cc/"
)

# Methods compared in the table, in column order.
method_name_list = [
    "gauss_densification",
    "pcn",
    "occnet",
    "deepsdf",
    "igr",
    "proposed_lo_0",
    "proposed_lo_800",
]
# method name -> {num_context -> {model_id -> chamfer distance}}; one fresh,
# empty inner dict per method, filled by the loading code that follows.
map_method_chamfer_distance = {name: {} for name in method_name_list}
# Model ids collected while loading the baseline ("proposed_lo_0") results.
model_id_set = set()

def _load_chamfer_distances(directories, num_context_key):
    """Read the per-model Chamfer distances of one method.

    Prints the number of metrics files found in each directory.

    Args:
        directories: Result directories of the method, one per context size.
        num_context_key: Metric key whose first value is the number of
            context/input points used for that directory.

    Returns:
        Dict ``{num_context: {model_id: chamfer_distance}}``. If two
        directories report the same ``num_context``, the later one wins.
    """
    per_context = {}
    for result_directory in directories:
        result, _, num_data = _summarize(result_directory)
        print(num_data)
        num_context = result[num_context_key][0]
        # ``model_id`` and ``chamfer_distance`` are parallel lists (one entry
        # per metrics file), so they are paired by position.
        per_context[num_context] = dict(
            zip(result["model_id"], result["chamfer_distance"]))
    return per_context


# (method name, its result directories, metric key holding the context size),
# in the order the results are loaded.
_method_sources = [
    ("proposed_lo_0", proposed_lo_0_directories,
     "latent_optimization_num_samples"),
    ("proposed_lo_800", proposed_lo_800_directories,
     "latent_optimization_num_samples"),
    ("gauss_densification", gauss_densification_directories,
     "num_input_points"),
    ("pcn", pcn_directories, "num_input_points"),
    ("occnet", occnet_directories, "num_input_points"),
    ("igr", igr_directories, "latent_optimization_num_samples"),
    ("deepsdf", deepsdf_directories, "latent_optimization_num_samples"),
]
for method_name, directories, num_context_key in _method_sources:
    map_method_chamfer_distance[method_name].update(
        _load_chamfer_distances(directories, num_context_key))

# The comparison below iterates over the models evaluated by the baseline
# method, across all of its context sizes.
for per_model in map_method_chamfer_distance["proposed_lo_0"].values():
    model_id_set.update(per_model)

# One row per (model, context size, method): the Chamfer distance scaled by
# 1000 for display, and its ratio to the baseline method's distance on the
# same model and context size. Missing or None distances are reported on
# stdout and skipped.
data_list = []
base_method_name = "proposed_lo_0"
# Sorted so the row order does not depend on set iteration order.
for model_id in sorted(model_id_set):
    category_id = model_id.split("_")[0]
    for num_context in [50, 100, 300, 1000]:
        # model_id_set spans all baseline context sizes, so a model (or a
        # whole context size) can be absent here; .get avoids a KeyError.
        base_chamfer_distance = map_method_chamfer_distance[
            base_method_name].get(num_context, {}).get(model_id)
        if base_chamfer_distance is None:
            print(model_id)
            continue
        for method_name in method_name_list:
            per_model = map_method_chamfer_distance[method_name].get(
                num_context, {})
            chamfer_distance = per_model.get(model_id)
            if chamfer_distance is None:
                print(method_name, model_id)
                continue
            relative_cd = chamfer_distance / base_chamfer_distance
            data_list.append({
                "chamfer_distance": chamfer_distance * 1000,
                "relative_cd": relative_cd,
                "method_name": method_name,
                "model_id": model_id,
                "category_id": category_id,
                "num_context": num_context,
            })

df = pd.DataFrame(data_list)

# Per (context size, method, category) mean and sample standard deviation of
# the two numeric metrics. The metric columns are selected explicitly:
# ``model_id`` is a string column, and since pandas 2.0 GroupBy.mean/std
# raise TypeError on non-numeric columns instead of silently dropping them.
grouped = df.groupby(["num_context", "method_name", "category_id"])[
    ["chamfer_distance", "relative_cd"]]
mean = grouped.mean()
std = grouped.std()

# Human-readable label for each method key.
map_method_name_label = dict(
    proposed_lo_0="Proposed Method w/o opt",
    proposed_lo_800="Proposed Method w/ opt",
    gauss_densification="Gauss Densification",
    pcn="PCN",
    occnet="OccNet",
    igr="IGR",
    deepsdf="DeepSDF",
)

# column name -> {(num_context, method_name, category_id) -> value}
mean_dict = mean.to_dict()
std_dict = std.to_dict()
# Reference sketch of the LaTeX layout for one category (this string is never
# used). The code below emits "$\pm$" cells with two decimals rather than the
# "(±x.xxx)" shown here. It is a raw string so that "\m", "\c", etc. are not
# parsed as (invalid) escape sequences, which Python 3.12+ warns about.
r"""
\multicolumn{1}{|l|}{\multirow{4}{*}{Bag}}      
& 50 
& 4.113 (±2.486) & 5.113 (±2.486) & 3.085 (±2.123)  & 6.219 (±7.010) 
& 3.465 (±1.491) & 2.479 (±2.385) & 1.223 (±0.855) \\ 
\cline{2-9} 
\multicolumn{1}{|l|}{}                          & 100 
& 4.113 (±2.486) & 5.113 (±2.486) & 3.085 (±2.123)  & 6.219 (±7.010) 
& 3.465 (±1.491) & 2.479 (±2.385) & 1.223 (±0.855) \\ 
\cline{2-9} 
\multicolumn{1}{|l|}{}                          & 300 
& 4.113 (±2.486) & 5.113 (±2.486) & 3.085 (±2.123)  & 6.219 (±7.010) 
& 3.465 (±1.491) & 2.479 (±2.385) & 1.223 (±0.855) \\ 
\cline{2-9} 
\multicolumn{1}{|l|}{}                          & 1000 
& 4.113 (±2.486) & 5.113 (±2.486) & 3.085 (±2.123)  & 6.219 (±7.010) 
& 3.465 (±1.491) & 2.479 (±2.385) & 1.223 (±0.855) \\ 
\cline{2-9} \hline
"""

# Emit one LaTeX table section per category. Each section has one row per
# context size, and each row holds "mean $\pm$ std" of the (x1000) Chamfer
# distance for every method, in method_name_list order. The method with the
# lowest mean in a row is set in bold. A category lacking a result for any
# (context size, method) pair is skipped entirely.
tex_results = []
tex_context_sizes = [50, 100, 300, 1000]
cd_mean_dict = mean_dict["chamfer_distance"]
cd_std_dict = std_dict["chamfer_distance"]
# Table columns: category, context size, then one per method.
tex_num_columns = 2 + len(method_name_list)


def _format_cd_cell(cd_mean, cd_std):
    """Format a mean/std pair as a LaTeX "mean plus-minus std" cell.

    Both values are rounded half-up to two decimals.
    """
    cd_mean = Decimal(cd_mean).quantize(Decimal("0.01"), ROUND_HALF_UP)
    cd_std = Decimal(cd_std).quantize(Decimal("0.01"), ROUND_HALF_UP)
    return f"{cd_mean} $\\pm$ {cd_std}"


for category_id, category_label in _map_category_name.items():
    if not all((num_context, method_name, category_id) in cd_mean_dict
               for num_context in tex_context_sizes
               for method_name in method_name_list):
        continue
    lines = []
    for row_index, num_context in enumerate(tex_context_sizes):
        if row_index == 0:
            # The first row opens the category cell spanning all rows.
            lines.append(
                f"\\multicolumn{{1}}{{|l|}}{{\\multirow{{{len(tex_context_sizes)}}}{{*}}{{{category_label}}}}}")
        else:
            lines.append("\\multicolumn{1}{|l|}{}")
        lines.append(f"& {num_context}")
        row_means = {
            method_name: cd_mean_dict[(num_context, method_name, category_id)]
            for method_name in method_name_list
        }
        # min() keeps the first of equal values, so ties go to the method
        # listed first.
        best_method = min(row_means, key=row_means.get)
        for method_name in method_name_list:
            key = (num_context, method_name, category_id)
            cell = _format_cd_cell(cd_mean_dict[key], cd_std_dict[key])
            if method_name == best_method:
                cell = f"\\bf{{{cell}}}"
            lines.append(f"& {cell}")
        lines.append("\\\\")
        lines.append(f"\\cline{{2-{tex_num_columns}}}")
    lines.append("\\hline")
    tex_results.append("\n".join(lines) + "\n")

# Write the per-category LaTeX sections to tex_cd.txt; each section already
# ends with a newline, so joining on "\n" leaves a blank line between them.
Path("tex_cd.txt").write_text("\n".join(tex_results))

# Disabled alternative report: for each context size, print a GitHub-style
# Markdown table (via tabulate) of the mean (± std) Chamfer distance relative
# to the baseline method, one column per category. Uncomment to use.
# for num_context in [50, 100, 300, 1000]:
#     rows = []
#     for method_name in method_name_list:
#         row = [map_method_name_label[method_name]]
#         headers = ["Method"]
#         for category_id in _map_category_name.keys():
#             key = (num_context, method_name, category_id)
#             if key not in mean_dict["relative_cd"]:
#                 continue
#             cd_mean = mean_dict["relative_cd"][key]
#             cd_std = std_dict["relative_cd"][key]
#             headers.append(_map_category_name[category_id])
#             row.append(f"{cd_mean:.03f} (±{cd_std:.03f})")
#         rows.append(row)
#
#     print(f"# context = {num_context}")
#     print(tabulate(rows, headers=headers, tablefmt="github"))
#     print("")
