from pathlib import Path
from collections import defaultdict
import json
from decimal import Decimal, ROUND_HALF_UP
import pandas as pd
from tabulate import tabulate
import math


def _summarize(result_directory: str):
    """Collect per-file metrics from ``<result_directory>/metrics/*.json``.

    Each JSON file holds one dict of metrics. The dicts are transposed into a
    column table where ``table[key][i]`` is the value of ``key`` in the i-th
    successfully parsed file. A file that lacks a key contributes ``None`` for
    that key, so every column has the same length and columns can safely be
    zipped together (e.g. ``model_id`` with ``chamfer_distance``).

    Args:
        result_directory: Directory containing a ``metrics`` sub-directory
            and, optionally, an ``args.json`` file.

    Returns:
        ``(table, args, num_files)``:
        - ``table``: ``defaultdict(list)`` mapping metric name to its column.
        - ``args``: the parsed ``args.json``, or ``None`` when it is absent.
        - ``num_files``: the number of ``*.json`` files found, including any
          that failed to parse.

    Raises:
        FileNotFoundError: If there is no ``*.json`` file under ``metrics``.
    """
    result_directory = Path(result_directory)
    metrics_directory = result_directory / "metrics"
    args_path = result_directory / "args.json"
    args = None
    if args_path.is_file():
        with open(args_path, encoding="utf-8") as f:
            args = json.load(f)

    # Sorted so that row order (and anything read from row 0) is deterministic
    # rather than depending on filesystem enumeration order.
    metrics_path_list = sorted(metrics_directory.glob("*.json"))
    if not metrics_path_list:
        # An explicit raise still fires under `python -O`; an assert would not.
        raise FileNotFoundError(
            f"No metrics/*.json files found in {result_directory}")

    records = []
    for metrics_path in metrics_path_list:
        try:
            with open(metrics_path, encoding="utf-8") as f:
                metrics = json.load(f)
        except json.JSONDecodeError:
            # Best effort: report and skip unreadable files.
            print("Error:", metrics_path)
            continue
        records.append(metrics)

    # Union of all keys in first-seen order. Filling gaps with None keeps the
    # columns aligned row-by-row even when files have different key sets.
    keys = dict.fromkeys(key for metrics in records for key in metrics)
    table = defaultdict(list)
    for metrics in records:
        for key in keys:
            table[key].append(metrics.get(key))

    return table, args, len(metrics_path_list)


# Each *_directories list below has one result directory per count in
# _NUM_CONTEXT_VALUES, in this order.
_NUM_CONTEXT_VALUES = (50, 100, 300, 1000)
_RESULT_ROOT = "/mnt/hdd/mnt/hdd/neurips"


def _directories_for(template, **fields):
    """Expand ``template`` once per count, filling ``{n}`` and ``fields``."""
    return [template.format(n=n, **fields) for n in _NUM_CONTEXT_VALUES]


_PROPOSED_TEMPLATE = (
    _RESULT_ROOT
    + "/meta_learning_sdf/surface_cd/uniform_sparse_sampling/baseline/0510/"
    "symmetric/cd_30000/40_categories_novel/epoch_5000/"
    "latent_optim_{latent_optim}/{n}_samples/lr_0.005/ec27d70e7cf0/")

proposed_lo_0_directories = _directories_for(_PROPOSED_TEMPLATE,
                                             latent_optim=0)
proposed_lo_800_directories = _directories_for(_PROPOSED_TEMPLATE,
                                               latent_optim=800)

# Note: these two PCN-family templates have no trailing slash.
gauss_densification_directories = _directories_for(
    _RESULT_ROOT
    + "/pcn/uniform_sparse_sampling/random_baseline/symmetric/"
    "variance_5e-05/40_categories_novel/{n}_inputs")
pcn_directories = _directories_for(
    _RESULT_ROOT
    + "/pcn/uniform_sparse_sampling/0525/symmetric/40_categories_novel/"
    "epoch_5000/{n}_inputs/81d2edd77142")

occnet_directories = _directories_for(
    _RESULT_ROOT
    + "/occupancy_networks/uniform_sparse_sampling/0517/symmetric/"
    "40_categories_novel/epoch_5000/{n}_inputs/b6ae866d88ba/")
igr_directories = _directories_for(
    _RESULT_ROOT
    + "/implicit-geometry-regularization/0510/uniform_sparse_sampling/"
    "cd_30000/40_categories_novel/epoch_5000/latent_optim_800/"
    "{n}_context/lr_0.05/eb771fe8d137/")
deepsdf_directories = _directories_for(
    _RESULT_ROOT
    + "/deep_sdf/0510/learning_shape_space/uniform_sparse_sampling/"
    "cd_30000/40_categories_novel/epoch_5000/latent_optim_800/"
    "{n}_context/lr_0.05/fd4c9568e7cc/")

# Methods being compared. The summary tables later emit one column per
# entry, in this order.
method_name_list = [
    "gauss_densification",
    "pcn",
    "occnet",
    "deepsdf",
    "igr",
    "proposed_lo_0",
    "proposed_lo_800",
]
# Nested lookup filled while loading: method -> num_context -> model_id -> CD.
map_method_chamfer_distance = {name: {} for name in method_name_list}
# Model ids reported by the "proposed_lo_0" results; filled while loading.
model_id_set = set()

# Load every method's per-model chamfer distances into
# map_method_chamfer_distance[method_name][num_context][model_id].
# Each entry is (method name, result directories, metrics key that records the
# number of context / input points). The list order is the loading order,
# which is also the order of the printed file counts.
_loading_plan = [
    ("proposed_lo_0", proposed_lo_0_directories,
     "latent_optimization_num_samples"),
    ("proposed_lo_800", proposed_lo_800_directories,
     "latent_optimization_num_samples"),
    ("pcn", pcn_directories, "num_input_points"),
    ("gauss_densification", gauss_densification_directories,
     "num_input_points"),
    ("occnet", occnet_directories, "num_input_points"),
    ("igr", igr_directories, "latent_optimization_num_samples"),
    ("deepsdf", deepsdf_directories, "latent_optimization_num_samples"),
]

for method_name, directories, num_context_key in _loading_plan:
    for result_directory in directories:
        result, _args, num_data = _summarize(result_directory)
        print(num_data)

        # Require all metrics files in the directory to agree on num_context.
        # Trusting row 0 alone would silently file a mixed directory under
        # the wrong num_context.
        num_context_values = {
            value for value in result[num_context_key] if value is not None
        }
        if len(num_context_values) != 1:
            raise ValueError(
                f"{result_directory}: expected exactly one {num_context_key} "
                f"value, found {num_context_values!r}")
        (num_context,) = num_context_values

        model_ids = result["model_id"]
        chamfer_distances = result["chamfer_distance"]
        # zip() would silently truncate and mis-pair ids with distances.
        if len(model_ids) != len(chamfer_distances):
            raise ValueError(
                f"{result_directory}: {len(model_ids)} model_id values but "
                f"{len(chamfer_distances)} chamfer_distance values")

        map_method_chamfer_distance[method_name][num_context] = dict(
            zip(model_ids, chamfer_distances))
        # Only the base method ("proposed_lo_0") defines the model set that
        # the comparison below iterates over.
        if method_name == "proposed_lo_0":
            model_id_set.update(model_ids)

# Build the long-format table: one row per (model, num_context, method) whose
# chamfer distance is known. Each row also stores the ratio to the base
# method's distance on the same model and num_context.
data_list = []
base_method_name = "proposed_lo_0"
base_results = map_method_chamfer_distance[base_method_name]
# Sorted (by str, in case ids mix types) so row order does not depend on the
# per-process hash seed that determines set iteration order.
for model_id in sorted(model_id_set, key=str):
    for num_context in [50, 100, 300, 1000]:
        # .get(): the base method need not have every model at every
        # num_context; direct indexing raised KeyError here.
        base_chamfer_distance = base_results.get(num_context, {}).get(model_id)
        # A missing/None base leaves relative_cd undefined. A zero base would
        # raise ZeroDivisionError. Skip the model for this num_context.
        if not base_chamfer_distance:
            print(model_id)
            continue
        for method_name in method_name_list:
            chamfer_distance = map_method_chamfer_distance[method_name].get(
                num_context, {}).get(model_id)
            if chamfer_distance is None:
                print(method_name, model_id)
                continue
            data_list.append({
                # Scaled by 1000 for display.
                "chamfer_distance": chamfer_distance * 1000,
                "relative_cd": chamfer_distance / base_chamfer_distance,
                "method_name": method_name,
                "model_id": model_id,
                "num_context": num_context,
            })

df = pd.DataFrame(data_list)

# Mean / std per (num_context, method). Only the numeric columns are
# aggregated: "model_id" is a string column, and since pandas 2.0
# GroupBy.mean()/std() raise TypeError on it instead of dropping it.
grouped = df.groupby(["num_context", "method_name"])[
    ["chamfer_distance", "relative_cd"]]
mean = grouped.mean()
std = grouped.std()
print(mean)
print(std)
print(mean.columns)
print(mean.axes)

print(mean.to_dict())

# Shape: {column: {(num_context, method_name): value}}
mean_dict = mean.to_dict()
std_dict = std.to_dict()

report_num_context_list = [50, 100, 300, 1000]
report_columns = ["chamfer_distance", "relative_cd"]

# Display label per method. Headers are derived from method_name_list, the
# same order the cells are filled in. This keeps labels in sync with the data;
# the former hard-coded header list was in a different order than the columns.
method_display_names = {
    "proposed_lo_0": "Proposed Method w/o opt",
    "proposed_lo_800": "Proposed Method w/ opt",
    "gauss_densification": "Gauss Densification",
    "pcn": "PCN",
    "occnet": "OccNet",
    "igr": "IGR",
    "deepsdf": "DeepSDF",
}
table_headers = ["num_context"] + [
    method_display_names.get(method_name, method_name)
    for method_name in method_name_list
]


def _round_half_up_2(value):
    """Round ``value`` to two decimals using ROUND_HALF_UP.

    The float goes through ``str`` (its shortest round-tripping decimal form),
    so 1.005 rounds to 1.01. ``Decimal(1.005)`` would instead use the exact
    binary value 1.00499... and round down to 1.00.
    """
    return Decimal(str(float(value))).quantize(Decimal("0.01"), ROUND_HALF_UP)


def _mean_std_cell(mean_by_column, std_by_column, column, num_context,
                   method_name):
    """Return the rounded ``(mean, std)`` of ``column`` for one table cell.

    Returns ``None`` when there are no rows for ``(num_context, method_name)``
    (e.g. every distance was missing) instead of raising ``KeyError``.
    """
    key = (num_context, method_name)
    if key not in mean_by_column[column]:
        return None
    return (_round_half_up_2(mean_by_column[column][key]),
            _round_half_up_2(std_by_column[column][key]))


# Markdown (GitHub-style) tables: absolute chamfer distance, then chamfer
# distance relative to the base method.
for column in report_columns:
    rows = []
    for num_context in report_num_context_list:
        row = [num_context]
        for method_name in method_name_list:
            cell = _mean_std_cell(mean_dict, std_dict, column, num_context,
                                  method_name)
            row.append("n/a" if cell is None else
                       f"{cell[0]:.02f} (±{cell[1]:.02f})")
        rows.append(row)
    print(tabulate(rows, headers=table_headers, tablefmt="github"))
    print()

# LaTeX cells ("mean $\pm$std &"), one per line, grouped by method.
for method_name in method_name_list:
    print(method_name)
    for column in report_columns:
        for num_context in report_num_context_list:
            cell = _mean_std_cell(mean_dict, std_dict, column, num_context,
                                  method_name)
            if cell is None:
                print("n/a &")
            else:
                # Raw f-string: "\p" is an invalid escape sequence in a
                # normal literal (DeprecationWarning, SyntaxWarning on 3.12+).
                print(rf"{cell[0]:.02f} $\pm${cell[1]:.02f} &")

    print()