#
# import matplotlib.pyplot as plt
# import pandas as pd
#
# plt.rcParams['figure.dpi']= 300
# import torch
# import numpy as np
# import vis_utils
# from utilities import utils
# import os
# import argparse
# from pipeline.load import *
# import torch.utils.data as data_utils
# from utilities.eval_quant import quantitative_evaluation
# from utilities.utils import record_prediction, denormalize, denormlize_ds, denormalize_from_distribution
# import uncertainty_toolbox as uct
# from functools import partial
#
#
# def flatten_dicts(dict_eval):
#     dict_flattenned = {}
#     for i_k, dict_i_v in dict_eval.items():
#         for i_name, i_val in dict_i_v.items():
#             if isinstance(i_val, float):
#                 dict_flattenned[i_k + '_' + i_name] = np.round(i_val, 4)
#     return dict_flattenned
#
#
#
# def test_(network: torch.nn.Module,
#         test_csa_dataset: torch.utils.data.Dataset,
#         savedir: str,
#           which_set: str):
#
#     utils.cond_mkdir(savedir)
#
#     model_input, arr_gt_csa, df_testset = \
#         utils.make_input_for_eval(test_csa_dataset.DATASETNANE)(scalar_dataset=test_csa_dataset, device=network.device)
#
#     #f_mu, f_var = network.infer_mu_and_var_testing(model_input)
#
#     with torch.no_grad():
#         ## query the model
#         f_mu, f_var = utils.batched_predict(model_trained=network,
#                                           arr_input_grids=model_input,
#                                           batch_size=1000)
#
#
#     f_mu, f_std = f_mu, torch.sqrt(f_var)
#
#     f_mu_ori, pred_std_ori, low_bd_map, high_bd_map = denormalize_from_distribution(ds_=test_csa_dataset, mu=f_mu, sigma=f_std)
#
#
#     # f_mu_ori = denormalize(ds_=test_csa_dataset,
#     #                    arr_=f_mu.squeeze().detach().cpu().numpy(),
#     #                    var_name=test_csa_dataset.tgt_var_name)
#
#     # pred_std_ori = denormalize(ds_=test_csa_dataset,
#     #                        arr_=f_std.squeeze().detach().cpu().numpy(),
#     #                       var_name=test_csa_dataset.tgt_var_name, WHETHER_STD=True)
#
#     arr_gt_csa_ori = denormalize(ds_=test_csa_dataset,
#                             arr_=arr_gt_csa.cpu().numpy(),
#                             var_name=test_csa_dataset.tgt_var_name)
#
#
#     f_mu = f_mu.squeeze().detach().cpu().numpy()
#     pred_std = f_std.squeeze().detach().cpu().numpy()
#     arr_gt_csa = arr_gt_csa.cpu().numpy()
#
#     df_pred_rst = record_prediction(denormlize_ds(test_csa_dataset, df_testset),
#                                     {'f_mu': f_mu, 'f_std': pred_std, 'normed_GT': arr_gt_csa,
#                                      'f_mu_ori': f_mu_ori, 'f_std_ori': pred_std_ori, 'normed_GT_ori': arr_gt_csa_ori})
#
#     savepath_pred_rst = os.path.join(savedir, f'pop_trend_pred_{which_set}.csv')
#     df_pred_rst.to_csv(savepath_pred_rst)
#
#     return df_pred_rst
#
#
# def eval_(dataset_name):
#     if 'Airway' in dataset_name or "AFQ" in dataset_name:
#         return partial(eval_airway, dataset_name = dataset_name)
#     else:
#         return eval_general
#
#
#
# def eval_airway(dir: str, which_set: int or str, dataset_name: str):
#     eval_general(dir, which_set)
#     filename_pred_rst = os.path.join(dir, f'pop_trend_pred_{which_set}.csv')
#     df_rst = pd.read_csv(filename_pred_rst)
#
#     # slt_percentiles = [0, 20, 40, 50, 60, 80, 100]
#     # list_pos = np.percentile(np.unique(df_rst['pos'].values), slt_percentiles)
#     CURRENT_LDMS = LANDMARKS[dataset_name]
#     list_landmark_names = list(CURRENT_LDMS.keys())
#
#     list_summary = []
#     for ith_pos in list_landmark_names: #range(len(list_pos)):
#         # f_mu = df_rst[df_rst['pos'] == list_pos[ith_pos]]['f_mu'].values
#         # f_std = df_rst[df_rst['pos'] == list_pos[ith_pos]]['f_std'].values
#         # normed_GT = df_rst[df_rst['pos'] == list_pos[ith_pos]]['normed_GT'].values
#
#         # f_mu = df_rst[(df_rst['pos'] - list_pos[ith_pos]).abs() <= 0.01]['f_mu'].values
#         # f_std = df_rst[(df_rst['pos'] - list_pos[ith_pos]).abs() <= 0.01]['f_std'].values
#         # normed_GT = df_rst[(df_rst['pos'] - list_pos[ith_pos]).abs() <= 0.01]['normed_GT'].values
#
#         f_mu = df_rst[(df_rst['pos'] - CURRENT_LDMS[ith_pos]).abs() <= 0.05]['f_mu_ori'].values
#         f_std = df_rst[(df_rst['pos'] - CURRENT_LDMS[ith_pos]).abs() <= 0.05]['f_std_ori'].values
#         normed_GT = df_rst[(df_rst['pos'] - CURRENT_LDMS[ith_pos]).abs() <= 0.05]['normed_GT_ori'].values
#
#
#
#         # Compute all uncertainty metrics
#         metrics = uct.metrics.get_all_metrics(f_mu, f_std, normed_GT)
#         dict_flattened = {}
#         #dict_flattened['pos'] = np.round(list_pos[ith_pos], 4)
#         dict_flattened['pos'] = np.round(CURRENT_LDMS[ith_pos], 4)
#         dict_flattened['landmark'] = ith_pos
#
#         dict_flattened.update(flatten_dicts(metrics))
#         list_summary.append(dict_flattened)
#     savepath = os.path.join(dir, f'pos_wise_summary_{which_set}.csv')
#     pd.DataFrame.from_records(list_summary).to_csv(savepath)
#
#     return
#
#
# def eval_general(dir, which_set):
#     filename_pred_rst = os.path.join(dir, f'pop_trend_pred_{which_set}.csv')
#     df_rst = pd.read_csv(filename_pred_rst)
#
#     f_mu = df_rst['f_mu'].values
#     f_std = df_rst['f_std'].values
#     normed_GT = df_rst['normed_GT'].values
#
#     # Compute all uncertainty metrics
#     metrics = uct.metrics.get_all_metrics(f_mu, f_std, normed_GT)
#     list_summary = []
#     dict_flattened = {}
#     dict_flattened.update(flatten_dicts(metrics))
#     list_summary.append(dict_flattened)
#     savepath = os.path.join(dir, f'all_summary_{which_set}.csv')
#     pd.DataFrame.from_records(list_summary).to_csv(savepath)
#     return
#
#
#
# def pred_and_eval_varying_sample_size(specs_filename: str, which_set: str, cv_idx: int=None, sample_size: int=200000):
#     specs = load_json(specs_filename, cv_idx=cv_idx, sample_size=sample_size)
#     ds_test, ds_test_dataloader = load_dataset(specs=specs, which_split=which_set)
#     trained_model = load_trained_model(specs, specs["SavedBestCheckpointPath"]) # saved checkpoint name
#
#     trained_model.eval()
#     savedir = os.path.join(specs["LoggingRoot"], f'{specs["ExperimentName"]}') #'/playpen-raid/Author/LucidAtlas/figures/v12'
#     utils.cond_mkdir(savedir)
#     dataset_name = specs["Class"]
#     test_(trained_model, ds_test, savedir, which_set)
#     eval_(dataset_name)(savedir, which_set)
#     return
#
#
#
#
#
# if __name__ == "__main__":
#     arg_parser = argparse.ArgumentParser(description="Train a LucidAtlas autodecoder")
#     arg_parser.add_argument(
#         "--experiment",
#         "-e",
#         dest="experiment_directory",
#         #default='/playpen-raid/Author/LucidAtlas/configs/airways/v1/airway_mlp.json',
#         #default='/playpen-raid/Author/LucidAtlas/configs/airways/airway_namlss_v1_0123_full.json',
#         #default='/playpen-raid/Author/LucidAtlas/configs/airways/airway_lucidatlas_v14_0123_part.json',
#         #default='/playpen-raid/Author/LucidAtlas/configs/airways/v1/airway_lucidatlas_v15_part.json',
#         default='/playpen-raid/Author/LucidAtlas/configs/airways/airway_mlp_v1_0123.json', #
#         #default="/playpen-raid/Author/LucidAtlas/configs/OASISBrain/v2/brain_"
#         help="The experiment directory. This directory should include "
#              + "experiment specifications in 'specs.json', and logging will be "
#              + "done in this directory as well.",
#     )
#     arg_parser.add_argument(
#         "--checkpoint",
#         "-c",
#         dest="checkpoint",
#         default="latest",
#         help="The checkpoint weights to use. This can be a number indicating an epoch "
#         + "or 'latest' for the latest weights (this is the default)",
#     )
#
#     arg_parser.add_argument(
#         "--train",
#         dest="whether_train",
#         default=True,
#         help="whether to train from scratch",
#     )
#
#     arg_parser.add_argument(
#         "--test",
#         dest="whether_test",
#         default=True,
#         help="whether to test",
#     )
#
#     arg_parser.add_argument(
#         "--vis",
#         dest="whether_vis",
#         default=True,
#         help="whether to vis",
#     )
#
#
#     arg_parser.add_argument(
#         "--sample_size",
#         "-ss",
#         dest="sample_size",
#         default=1,
#         help="sample size passed through to load_json",
#     )
#
#
#     args = arg_parser.parse_args()
#
#
#     if args.whether_test:
#         pred_and_eval_varying_sample_size(args.experiment_directory, which_set='test',  sample_size=args.sample_size)
#     print('1')
#
#
