import argparse
import numpy as np
import pandas as pd

from utilities import utils
import torch
import model.networks.basics.workspace as ws
import os

import uncertainty_toolbox as uct

def flatten_dicts(dict_eval):
    """Flatten a two-level metrics dict into one level of rounded scalar floats.

    Every scalar float found at ``dict_eval[group][metric]`` becomes the entry
    ``'<group>_<metric>'``, rounded to 4 decimals with ``np.round``. Values that
    are not floating-point scalars (ints, lists, arrays, nested dicts, ...) are
    skipped.

    Args:
        dict_eval: mapping from a group name (str) to a mapping of
            metric name (str) -> value.

    Returns:
        dict mapping ``'<group>_<metric>'`` to the rounded value.
    """
    dict_flattened = {}
    for group_name, group_metrics in dict_eval.items():
        for metric_name, metric_val in group_metrics.items():
            # NumPy scalars such as float32 are NOT subclasses of Python float
            # (only float64 is), so a plain `float` check silently dropped them.
            if isinstance(metric_val, (float, np.floating)):
                # Upcast to a Python float first: identical result for values the
                # old check accepted, and float32 values round in double precision.
                dict_flattened[group_name + '_' + metric_name] = np.round(float(metric_val), 4)
    return dict_flattened


def quantitative_evaluation_airway(network: torch.nn.Module,
                                   train_csa_dataset: torch.utils.data.Dataset,
                                   savedir: str):
    """Compute uncertainty metrics at several percentiles of the training positions.

    For each of the 0/20/40/50/60/80/100th percentiles of
    ``train_csa_dataset.train_valid_pos``, builds a model input with
    ``utils.make_input_for_vis_comp_airway``, queries the network for a
    predictive mean and variance, and scores them against the returned ground
    truth with ``uncertainty_toolbox.metrics.get_all_metrics``. One record per
    position is written to ``<savedir>/summary.csv``.

    Args:
        network: model under evaluation. NOTE(review): it must expose ``.device``
            and ``infer_mu_and_var_testing(input) -> (mean, variance)`` tensors.
            The second output is square-rooted to obtain a standard deviation.
        train_csa_dataset: dataset with a ``train_valid_pos`` attribute.
        savedir: output directory; created via ``utils.cond_mkdir``.

    Returns:
        list of dicts, one per probed position, each with key ``'pos'``
        (rounded to 4 decimals) followed by the flattened float metrics.
    """
    utils.cond_mkdir(savedir)

    percentile_levels = [0, 20, 40, 50, 60, 80, 100]
    probe_positions = np.percentile(train_csa_dataset.train_valid_pos, percentile_levels)

    records = []
    for position in probe_positions:
        model_input, gt_tensor = utils.make_input_for_vis_comp_airway(
            train_csa_dataset, position, network.device)
        mu_tensor, var_tensor = network.infer_mu_and_var_testing(model_input)

        # Move predictions to NumPy; the std is the square root of the variance.
        pred_mean = mu_tensor.squeeze().detach().cpu().numpy()
        pred_std = var_tensor.squeeze().detach().sqrt().cpu().numpy()
        gt_values = gt_tensor.cpu().numpy()

        all_metrics = uct.metrics.get_all_metrics(pred_mean, pred_std, gt_values)
        # 'pos' is inserted first so it becomes the first CSV column.
        records.append({'pos': np.round(position, 4), **flatten_dicts(all_metrics)})

    summary_path = os.path.join(savedir, 'summary.csv')
    pd.DataFrame.from_records(records).to_csv(summary_path)
    return records



def quantitative_evaluation(dataset_name):
    """Return the quantitative-evaluation function for ``dataset_name``.

    Args:
        dataset_name: dataset identifier; currently only ``"Airway"`` is supported.

    Returns:
        The evaluation callable for that dataset
        (``quantitative_evaluation_airway`` for ``"Airway"``).

    Raises:
        ValueError: if ``dataset_name`` is not a supported dataset.
    """
    if dataset_name == "Airway":
        return quantitative_evaluation_airway
    # Fail loudly here instead of returning None, which would only surface later
    # as an opaque "'NoneType' object is not callable" at the call site.
    raise ValueError(
        f"Unsupported dataset for quantitative evaluation: {dataset_name!r} "
        "(supported: 'Airway')"
    )











# if __name__ == "__main__":
#     arg_parser = argparse.ArgumentParser(description="Train a LucidAtlas autodecoder")
#     arg_parser.add_argument(
#         "--experiment",
#         "-e",
#         dest="experiment_directory",
#         default='../config/airways/lucid_1d_csa_v1_lip.json',
#         help="The experiment directory. This directory should include "
#              + "experiment specifications in 'specs.json', and logging will be "
#              + "done in this directory as well.",
#     )
#     arg_parser.add_argument(
#         "--model_ckpt",
#         "-mc",
#         dest="model_ckpt",
#         default="latest",
#         help="The model checkpoint weights to use. This can be a number indicating an epoch "
#         + "or 'latest' for the latest weights (this is the default)",
#     )
#     arg_parser.add_argument(
#         "--la_ckpt",
#         "-lc",
#         dest="la_ckpt",
#         default="state_dict",
#         help="The LA checkpoint weights to use.",
#     )
#
#     args = arg_parser.parse_args()
#
#     la, csa_dataset, csa_dataloader = load_model_and_data(
#         specs_filename=args.experiment_directory,
#         filename_model_ckpt=args.model_ckpt,
#         filename_LA_ckpt=args.la_ckpt
#         )
#     quantitative_evaluation(la, csa_dataset, args.experiment_directory)
#     print('1')