from pathlib import Path
from collections import defaultdict
import json
import pandas as pd
from tabulate import tabulate
from decimal import Decimal, ROUND_HALF_UP


def _summarize(result_directory: str):
    """Collect the per-model metrics of one evaluation run into columns.

    Reads ``<result_directory>/args.json`` (if it exists) and every
    ``<result_directory>/metrics/*.json`` file. Each metrics file is expected
    to contain one JSON object; its values are gathered column-wise so that
    ``table[key][i]`` comes from the i-th successfully parsed file (files are
    visited in sorted path order). Files that are not valid JSON are reported
    with a printed message and skipped.

    Args:
        result_directory: Directory containing a ``metrics/`` subdirectory
            and optionally ``args.json``.

    Returns:
        A tuple ``(table, args, num_files)`` where ``table`` is a
        ``defaultdict(list)`` mapping each metric key to its per-file values
        (``None`` where a file lacks that key, so every column has the same
        length), ``args`` is the parsed ``args.json`` or ``None``, and
        ``num_files`` is the number of ``*.json`` files found, including ones
        that failed to parse.
    """
    result_directory = Path(result_directory)
    metrics_directory = result_directory / "metrics"
    args_path = result_directory / "args.json"
    args = None
    if args_path.is_file():
        with open(args_path, encoding="utf-8") as f:
            args = json.load(f)

    # Sort so the row order does not depend on filesystem directory order.
    metrics_path_list = sorted(metrics_directory.glob("*.json"))
    records = []
    for metrics_path in metrics_path_list:
        try:
            with open(metrics_path, encoding="utf-8") as f:
                records.append(json.load(f))
        except json.JSONDecodeError:
            print("Error:", metrics_path)

    # Union of keys in first-seen order. Missing keys are padded with None so
    # that zipping two columns never pairs values coming from different files.
    keys = dict.fromkeys(key for record in records for key in record)
    table = defaultdict(list)
    for record in records:
        for key in keys:
            table[key].append(record.get(key))

    return table, args, len(metrics_path_list)


# Root under which every method's evaluation results are stored.
_RESULT_ROOT = "/mnt/hdd/mnt/hdd/neurips"
# Context sizes (number of input points) evaluated per method, one result
# directory per size, in this order.
_CONTEXT_SIZES = (50, 100, 300, 1000)

# Proposed method evaluated without latent optimization steps.
proposed_lo_0_directories = [
    f"{_RESULT_ROOT}/meta_learning_sdf/surface_cd/uniform_sparse_sampling/"
    "baseline/0531/symmetric/cd_30000/icl_nuim/epoch_5000/latent_optim_0/"
    f"{size}_samples/lr_0.005/1d7bd0ff765b/"
    for size in _CONTEXT_SIZES
]

# Proposed method evaluated with 800 latent optimization steps.
proposed_lo_800_directories = [
    f"{_RESULT_ROOT}/meta_learning_sdf/surface_cd/uniform_sparse_sampling/"
    "baseline/0531/symmetric/cd_30000/icl_nuim/epoch_5000/latent_optim_800/"
    f"{size}_samples/lr_0.005/1d7bd0ff765b/"
    for size in _CONTEXT_SIZES
]

# NOTE(review): all four entries point at the same 1000-input directory;
# presumably only that run exists for this baseline — confirm before adding
# the smaller context sizes.
gauss_densification_directories = [
    f"{_RESULT_ROOT}/pcn/icl_nuim/random_baseline/symmetric/variance_0.0005/"
    "1000_inputs"
] * len(_CONTEXT_SIZES)

pcn_directories = [
    f"{_RESULT_ROOT}/pcn/uniform_sparse_sampling/0531/symmetric/icl_nuim/"
    f"epoch_5000/{size}_inputs/b860c95e5203"
    for size in _CONTEXT_SIZES
]

occnet_directories = [
    f"{_RESULT_ROOT}/occupancy_networks/uniform_sparse_sampling/0531/"
    f"symmetric/icl_nuim/epoch_3000/{size}_inputs/a801cb8e67ab/"
    for size in _CONTEXT_SIZES
]

# NOTE(review): "idl_nuim" differs from the "icl_nuim" used by every other
# method — confirm this matches the directory actually on disk.
igr_directories = [
    f"{_RESULT_ROOT}/implicit-geometry-regularization/0531/"
    "uniform_sparse_sampling/symmetric/cd_30000/idl_nuim/epoch_5000/"
    f"latent_optim_800/{size}_context/lr_0.05/cbbaf94dc7d3/"
    for size in _CONTEXT_SIZES
]

deepsdf_directories = [
    f"{_RESULT_ROOT}/deep_sdf/0531/uniform_sparse_sampling/symmetric/"
    "cd_30000/icl_nuim/epoch_3000/latent_optim_800/"
    f"{size}_context/lr_0.05/6f26a80fb1fb/"
    for size in _CONTEXT_SIZES
]

# Methods being compared; this order is also the column order of the
# printed tables.
method_name_list = [
    "gauss_densification",
    "pcn",
    "occnet",
    "deepsdf",
    "igr",
    "proposed_lo_0",
    "proposed_lo_800",
]
# method name -> {num_context -> {model_id -> chamfer distance}}; each method
# starts with its own empty dict and is filled by the loading code.
map_method_chamfer_distance = {name: {} for name in method_name_list}
# Model ids collected while loading results (filled by the loading code).
model_id_set = set()

def _collect_chamfer_distances(method_name, directories, num_context_key,
                               track_model_ids=False):
    """Load the per-model Chamfer distances of one method.

    Each result directory is read with ``_summarize`` and its number of
    metrics files is printed. The context size is taken from the first value
    of ``num_context_key``, and the ``model_id`` -> ``chamfer_distance``
    mapping is stored in
    ``map_method_chamfer_distance[method_name][num_context]``. A later
    directory with the same context size replaces an earlier one.

    Args:
        method_name: Key into ``map_method_chamfer_distance``.
        directories: Result directories to load, in order.
        num_context_key: Metrics key holding the context size for this method
            (e.g. "latent_optimization_num_samples" or "num_input_points").
        track_model_ids: If true, also add every loaded model id to
            ``model_id_set``.

    Raises:
        ValueError: If a directory yields no value for ``num_context_key``,
            or its ``model_id`` and ``chamfer_distance`` columns differ in
            length (zipping them would silently mis-pair models and
            distances).
    """
    for result_directory in directories:
        result, _, num_data = _summarize(result_directory)
        print(num_data)
        if not result[num_context_key]:
            raise ValueError(
                f"No '{num_context_key}' values found under {result_directory}")
        num_context = result[num_context_key][0]
        model_ids = result["model_id"]
        chamfer_distances = result["chamfer_distance"]
        if len(model_ids) != len(chamfer_distances):
            raise ValueError(
                f"{result_directory}: {len(model_ids)} model ids but "
                f"{len(chamfer_distances)} chamfer distances")
        map_method_chamfer_distance[method_name][num_context] = dict(
            zip(model_ids, chamfer_distances))
        if track_model_ids:
            model_id_set.update(model_ids)


# (method name, result directories, metrics key holding the context size).
# The list order determines the order of the printed file counts.
_METHOD_SOURCES = [
    ("proposed_lo_0", proposed_lo_0_directories,
     "latent_optimization_num_samples"),
    ("proposed_lo_800", proposed_lo_800_directories,
     "latent_optimization_num_samples"),
    ("pcn", pcn_directories, "num_input_points"),
    ("gauss_densification", gauss_densification_directories,
     "num_input_points"),
    ("occnet", occnet_directories, "num_input_points"),
    ("igr", igr_directories, "latent_optimization_num_samples"),
    ("deepsdf", deepsdf_directories, "latent_optimization_num_samples"),
]

for method_name, directories, num_context_key in _METHOD_SOURCES:
    # Only proposed_lo_0 (the comparison base) defines the set of models
    # that are compared later.
    _collect_chamfer_distances(
        method_name, directories, num_context_key,
        track_model_ids=(method_name == "proposed_lo_0"))

# One record per (model, method) at each evaluated context size. The Chamfer
# distance is multiplied by 1000 for display, and relative_cd is its ratio to
# the base method's distance on the same model.
data_list = []
base_method_name = "proposed_lo_0"
# Sort the ids: iterating a set of strings follows hash order, which varies
# between interpreter runs, so diagnostics, row order and the float summation
# order of the later aggregation would otherwise not be reproducible.
for model_id in sorted(model_id_set, key=str):
    for num_context in [1000]:
        base_chamfer_distance = map_method_chamfer_distance[base_method_name][
            num_context][model_id]
        if base_chamfer_distance is None:
            # No base distance means no ratio can be formed; skip the model.
            print(model_id)
            continue
        for method_name in method_name_list:
            # A missing model and a null distance are treated the same way:
            # report and skip.
            chamfer_distance = map_method_chamfer_distance[method_name][
                num_context].get(model_id)
            if chamfer_distance is None:
                print(method_name, model_id)
                continue
            data_list.append({
                "chamfer_distance": chamfer_distance * 1000,
                "relative_cd": chamfer_distance / base_chamfer_distance,
                "method_name": method_name,
                "model_id": model_id,
                "num_context": num_context,
            })

df = pd.DataFrame(data_list)

# Aggregate only the numeric metric columns. "model_id" is a string column, and
# since pandas 2.0 groupby mean()/std() raise TypeError on non-numeric columns
# (older versions silently dropped them with a FutureWarning).
grouped = df.groupby(["num_context",
                      "method_name"])[["chamfer_distance", "relative_cd"]]
mean = grouped.mean()
std = grouped.std()
print(mean)
print(std)
print(mean.columns)
print(mean.axes)

# Nested dicts: metric column -> {(num_context, method_name) -> value}.
mean_dict = mean.to_dict()
std_dict = std.to_dict()
print(mean_dict)

# Header label for each method. The headers are generated from
# method_name_list, the same list that orders the row cells, so each label
# always sits above its own column.
method_display_names = {
    "gauss_densification": "Gauss Densification",
    "pcn": "PCN",
    "occnet": "OccNet",
    "deepsdf": "DeepSDF",
    "igr": "IGR",
    "proposed_lo_0": "Proposed Method w/o opt",
    "proposed_lo_800": "Proposed Method w/ opt",
}
table_headers = ["num_context"] + [
    method_display_names[method_name] for method_name in method_name_list
]


def _format_mean_std(mean_value, std_value):
    """Render ``"mean (±std)"`` with both values rounded to 3 decimals.

    Decimal quantization with ROUND_HALF_UP rounds exact ties upward, whereas
    float formatting alone rounds ties to even.
    """
    quantum = Decimal("0.001")
    mean_value = Decimal(mean_value).quantize(quantum, ROUND_HALF_UP)
    std_value = Decimal(std_value).quantize(quantum, ROUND_HALF_UP)
    return f"{mean_value:.03f} (±{std_value:.03f})"


def _build_table_rows(metric_name, num_context_list=(1000,)):
    """Return one row ``[num_context, cell per method...]`` per context size."""
    return [
        [num_context] + [
            _format_mean_std(mean_dict[metric_name][(num_context, method_name)],
                             std_dict[metric_name][(num_context, method_name)])
            for method_name in method_name_list
        ]
        for num_context in num_context_list
    ]


print(
    tabulate(_build_table_rows("chamfer_distance"),
             headers=table_headers,
             tablefmt="github"))
print()

print(
    tabulate(_build_table_rows("relative_cd"),
             headers=table_headers,
             tablefmt="github"))

# For each method, print LaTeX table cells of the form "mean $\pm$ std &":
# first the Chamfer distance (x1000), then the ratio to the base method.
for method_name in method_name_list:
    print(method_name)
    for metric_name in ("chamfer_distance", "relative_cd"):
        for num_context in [1000]:
            key = (num_context, method_name)
            cd_mean = Decimal(mean_dict[metric_name][key]).quantize(
                Decimal("0.001"), ROUND_HALF_UP)
            cd_std = Decimal(std_dict[metric_name][key]).quantize(
                Decimal("0.001"), ROUND_HALF_UP)
            # Raw f-string: "\p" in a normal string literal is an invalid
            # escape sequence (SyntaxWarning on Python 3.12+); output is the
            # same backslash-p text either way.
            print(rf"{cd_mean:.03f} $\pm${cd_std:.03f} &")

    print()
