import sys
sys.path.append("/playpen-raid/Author/LucidAtlas")
import matplotlib.pyplot as plt
import pandas as pd

plt.rcParams['figure.dpi']= 300
import torch
import numpy as np
import vis_utils
from utilities import utils
import os
import argparse
from pipeline.load import *
import torch.utils.data as data_utils
from utilities.eval_quant import quantitative_evaluation
from utilities.utils import record_prediction, denormalize, denormlize_ds, denormalize_from_distribution
import uncertainty_toolbox as uct
from functools import partial


def flatten_dicts(dict_eval):
    """Collapse a two-level metrics mapping into a single flat dict.

    Each leaf ``dict_eval[group][metric]`` that is a ``float`` (including
    subclasses such as ``np.float64``) becomes the key ``group + '_' + metric``,
    rounded to 4 decimals with ``np.round``. Non-float leaves (ints, lists,
    arrays, nested dicts, ...) are dropped.

    Args:
        dict_eval: mapping of group name -> mapping of metric name -> value,
            e.g. the output of ``uct.metrics.get_all_metrics``.

    Returns:
        A new flat ``dict``; later duplicate keys overwrite earlier ones.
    """
    return {
        group + '_' + metric: np.round(value, 4)
        for group, group_metrics in dict_eval.items()
        for metric, value in group_metrics.items()
        if isinstance(value, float)
    }

def extract_queries(list_cov: list, num_of_cov: str):
    """Build the covariate queries used to probe the model.

    Args:
        list_cov: covariate names (any sequence; a list or a tuple).
        num_of_cov: query mode.
            * ``"1"``   - one query per covariate; ``list_cov`` is returned
              unchanged, so each query is a single covariate name.
            * ``"n-1"`` - leave-one-out: for every covariate, a new list
              containing all the *other* covariates.

    Returns:
        ``list_cov`` itself for ``"1"``; a list of lists for ``"n-1"``.

    Raises:
        ValueError: if ``num_of_cov`` is not one of the supported modes.
            (Previously the function fell through and returned ``None``,
            which only surfaced later as an opaque ``len(None)`` error.)
    """
    if num_of_cov == "1":
        return list_cov
    if num_of_cov == "n-1":
        list_list_cov = []
        for ith_cov in list_cov:
            # list(...) instead of .copy() so tuples are accepted too.
            list_cov_rst = list(list_cov)
            # remove() drops the first occurrence, matching the original semantics.
            list_cov_rst.remove(ith_cov)
            list_list_cov.append(list_cov_rst)
        return list_list_cov
    raise ValueError(
        f"Unsupported num_of_cov={num_of_cov!r}; expected '1' or 'n-1'."
    )


def test_(
        type_of_interpret: str,
        num_of_covariates: str,
        model_trained: torch.nn.Module,
        test_csa_dataset: torch.utils.data.Dataset,
        savedir: str,
        which_set: str):
    """Query a trained model once per covariate set and save all predictions.

    For every query returned by ``extract_queries`` (one covariate name per
    query for ``"1"``, a leave-one-out list for ``"n-1"``), the full evaluation
    input is run through ``utils.batched_infer`` (batch size 300, under
    ``torch.no_grad``). The per-query outputs are concatenated along dim 0.
    The ground truth and the input table are repeated once per query so that
    rows line up with the stacked predictions. The results are denormalized
    and written to
    ``<savedir>/<type_of_interpret>_<num_of_covariates>_pop_trend_pred_<which_set>.csv``.

    Args:
        type_of_interpret: interpretation mode forwarded to
            ``utils.batched_infer`` (this file uses ``'global'`` and ``'indp'``).
        num_of_covariates: ``"1"`` or ``"n-1"``; see ``extract_queries``.
        model_trained: trained model, switched to eval mode here. Its
            ``.device`` attribute is read, which a stock ``torch.nn.Module``
            does not provide (NOTE(review): project models presumably define it).
        test_csa_dataset: evaluation dataset. Side effect: its
            ``padding_muter`` flag is set to ``False``.
        savedir: output directory, passed to ``utils.cond_mkdir``
            (presumably created if missing).
        which_set: split label, used only in the output filename.

    Returns:
        The DataFrame built by ``record_prediction``; it is also saved as CSV.
    """
    utils.cond_mkdir(savedir)

    test_csa_dataset.padding_muter = False
    build_eval_inputs = utils.make_input_for_eval(test_csa_dataset.DATASETNANE)
    model_input, gt_tensor, df_inputs = build_eval_inputs(
        scalar_dataset=test_csa_dataset, device=model_trained.device)

    queries = extract_queries(list_cov=test_csa_dataset.covariate_names,
                              num_of_cov=num_of_covariates)

    model_trained = model_trained.eval()

    mu_chunks, var_chunks, feat_labels = [], [], []
    for query in queries:
        with torch.no_grad():
            mu_q, var_q = utils.batched_infer(
                type_of_interpret, model_trained, model_input, query, batch_size=300)
        mu_chunks.append(mu_q)
        var_chunks.append(var_q)
        # One 'Feat' label per predicted row of this query.
        feat_labels.extend([query] * mu_q.shape[0])

    n_queries = len(queries)
    f_mu = torch.cat(mu_chunks, dim=0)
    # Floor the variance, presumably so std stays strictly positive for the
    # downstream uncertainty metrics.
    f_var = torch.clamp(torch.cat(var_chunks, dim=0), min=1e-4)
    gt_stacked = torch.cat([gt_tensor] * n_queries, dim=0)

    # The last two outputs (lower/upper bounds) are not used here.
    f_mu_ori, f_std_ori, _low_bd, _high_bd = denormalize_from_distribution(
        ds_=test_csa_dataset, mu=f_mu, sigma=torch.sqrt(f_var))

    gt_ori = denormalize(ds_=test_csa_dataset,
                         arr_=gt_stacked.cpu().numpy(),
                         var_name=test_csa_dataset.tgt_var_name)

    mu_np = f_mu.squeeze().detach().cpu().numpy()
    std_np = f_var.squeeze().detach().sqrt().cpu().numpy()
    gt_np = gt_stacked.cpu().numpy()

    df_inputs_ori = denormlize_ds(test_csa_dataset, pd.concat([df_inputs] * n_queries, axis=0))
    columns = {
        'Feat': feat_labels,
        'f_mu': mu_np,
        'f_std': std_np,
        'normed_GT': gt_np,
        'f_mu_ori': f_mu_ori,
        'f_std_ori': f_std_ori,
        'normed_GT_ori': gt_ori,
    }
    df_pred_rst = record_prediction(df_inputs_ori, columns)

    out_csv = os.path.join(
        savedir, f'{type_of_interpret}_{num_of_covariates}_pop_trend_pred_{which_set}.csv')
    df_pred_rst.to_csv(out_csv)

    return df_pred_rst


def eval_(dataset_name):
    """Pick the evaluation routine for a dataset.

    Datasets whose name contains ``'Airway'`` or ``"AFQ"`` get
    ``eval_spatial`` with ``dataset_name`` pre-bound via ``functools.partial``.
    Every other dataset gets ``eval_general``.

    Args:
        dataset_name: dataset identifier string (e.g. ``specs["Class"]``).

    Returns:
        A callable accepting ``(type_of_interpret, num_of_covariates, dir, which_set)``.
    """
    is_spatial = any(tag in dataset_name for tag in ('Airway', "AFQ"))
    if not is_spatial:
        return eval_general
    return partial(eval_spatial, dataset_name=dataset_name)



def eval_spatial(type_of_interpret: str, num_of_covariates: str, dir: str, which_set: "int | str",
                 dataset_name: str, pos_tolerance: float = 0.05):
    """Compute uncertainty metrics per covariate query and per anatomical landmark.

    First runs ``eval_general`` (the per-query summary over all rows). Then it
    re-reads the prediction CSV written by ``test_``. For each distinct
    ``Feat`` value and each landmark in ``LANDMARKS[dataset_name]``, it keeps
    the rows whose ``pos`` lies within ``pos_tolerance`` of the landmark and
    computes ``uct.metrics.get_all_metrics`` on their
    ``f_mu`` / ``f_std`` / ``normed_GT`` columns. One row per
    (Feat, landmark) is written to
    ``<dir>/<type_of_interpret>_<num_of_covariates>_pos_wise_summary_<which_set>.csv``.

    Args:
        type_of_interpret: interpretation mode (part of the file names).
        num_of_covariates: ``"1"`` or ``"n-1"`` (part of the file names).
        dir: directory holding the prediction CSV; outputs go here too.
            (The name shadows the builtin but is kept for caller compatibility.)
        which_set: split label used in the file names.
        dataset_name: key into ``LANDMARKS``. NOTE(review): ``LANDMARKS`` is
            not defined in this file's visible code; presumably it comes from
            ``pipeline.load`` and maps landmark name -> position.
        pos_tolerance: half-width of the position window around each landmark
            (default 0.05, the previously hard-coded value).

    A (Feat, landmark) pair with no rows inside the window is no longer passed
    to the metrics routine as empty arrays. Instead a warning is emitted and
    the row is kept with only ``pos``/``landmark``/``Feat``, so its metric
    columns come out as NaN in the CSV.
    """
    import warnings

    eval_general(type_of_interpret, num_of_covariates, dir, which_set)
    filename_pred_rst = os.path.join(dir, f'{type_of_interpret}_{num_of_covariates}_pop_trend_pred_{which_set}.csv')
    df_rst = pd.read_csv(filename_pred_rst)

    CURRENT_LDMS = LANDMARKS[dataset_name]
    list_cov_names = np.unique(list(df_rst["Feat"].values))

    list_summary = []
    for ith_feat in list_cov_names:
        # Filter by feature once, rather than once per landmark; same rows, same order.
        df_feat = df_rst[df_rst["Feat"] == ith_feat]
        for ldm_name, ldm_pos in CURRENT_LDMS.items():
            slt_data = df_feat[(df_feat['pos'] - ldm_pos).abs() <= pos_tolerance]

            dict_flattened = {
                'pos': np.round(ldm_pos, 4),
                'landmark': ldm_name,
                'Feat': ith_feat,
            }
            if slt_data.empty:
                warnings.warn(
                    f"eval_spatial: no rows for Feat={ith_feat!r} within "
                    f"{pos_tolerance} of landmark {ldm_name!r} (pos={ldm_pos}); "
                    "metrics left empty."
                )
            else:
                metrics = uct.metrics.get_all_metrics(
                    slt_data['f_mu'].values,
                    slt_data['f_std'].values,
                    slt_data['normed_GT'].values,
                )
                dict_flattened.update(flatten_dicts(metrics))
            list_summary.append(dict_flattened)

    savepath = os.path.join(dir, f'{type_of_interpret}_{num_of_covariates}_pos_wise_summary_{which_set}.csv')
    pd.DataFrame.from_records(list_summary).to_csv(savepath)

    return


def eval_general(type_of_interpret: str, num_of_covariates: str, dir: str, which_set: str):
    """Summarize uncertainty metrics for every covariate query.

    Reads ``<dir>/<type_of_interpret>_<num_of_covariates>_pop_trend_pred_<which_set>.csv``
    (as written by ``test_``). It groups rows by each distinct ``Feat`` value,
    in ``np.unique`` (sorted) order, and runs ``uct.metrics.get_all_metrics``
    on the ``f_mu`` / ``f_std`` / ``normed_GT`` columns of each group. The
    flattened metrics are written to
    ``<dir>/<type_of_interpret>_<num_of_covariates>_all_summary_<which_set>.csv``.

    Args:
        type_of_interpret: interpretation mode (part of the file names).
        num_of_covariates: ``"1"`` or ``"n-1"`` (part of the file names).
        dir: input/output directory (name kept for caller compatibility).
        which_set: split label used in the file names.
    """
    stem = f'{type_of_interpret}_{num_of_covariates}'
    df_rst = pd.read_csv(os.path.join(dir, f'{stem}_pop_trend_pred_{which_set}.csv'))

    def _summarize(feat):
        # One summary row: the query label followed by its flattened metrics.
        subset = df_rst[df_rst["Feat"] == feat]
        metrics = uct.metrics.get_all_metrics(subset['f_mu'].values,
                                              subset['f_std'].values,
                                              subset['normed_GT'].values)
        row = {'Feat': feat}
        row.update(flatten_dicts(metrics))
        return row

    records = [_summarize(feat) for feat in np.unique(list(df_rst["Feat"].values))]

    out_csv = os.path.join(dir, f'{stem}_all_summary_{which_set}.csv')
    pd.DataFrame.from_records(records).to_csv(out_csv)
    return



def pred_and_eval_feat_interpret(specs_filename: str,
                                 which_set: str,
                                 cv_idx: int=None,
                                 type_of_interpret: str='global',
                                 num_of_covariates="1"):
    """Run prediction and evaluation for one experiment configuration.

    Steps:
    1. Load the specs with ``load_json(specs_filename, cv_idx)``.
    2. Load the ``which_set`` split and the checkpoint at
       ``specs["SavedBestCheckpointPath"]``.
    3. Write predictions via ``test_`` into
       ``<LoggingRoot>/<ExperimentName>``.
    4. Dispatch to the dataset-specific evaluator from ``eval_(specs["Class"])``.

    Args:
        specs_filename: experiment specification file passed to ``load_json``.
        which_set: split name (e.g. ``'test'``).
        cv_idx: cross-validation fold index forwarded to ``load_json``.
        type_of_interpret: interpretation mode (``'global'`` or ``'indp'`` here).
        num_of_covariates: ``"1"`` or ``"n-1"``; see ``extract_queries``.
    """
    specs = load_json(specs_filename, cv_idx)

    ds_eval, _unused_loader = load_dataset(specs=specs, which_split=which_set)
    model = load_trained_model(specs, specs["SavedBestCheckpointPath"])
    model.eval()

    out_dir = os.path.join(specs["LoggingRoot"], specs["ExperimentName"])
    utils.cond_mkdir(out_dir)
    dataset_class = specs["Class"]

    test_(type_of_interpret=type_of_interpret,
          num_of_covariates=num_of_covariates,
          model_trained=model,
          test_csa_dataset=ds_eval,
          savedir=out_dir,
          which_set=which_set)

    evaluator = eval_(dataset_class)
    evaluator(type_of_interpret=type_of_interpret,
              num_of_covariates=num_of_covariates,
              dir=out_dir,
              which_set=which_set)
    return





if __name__ == "__main__":

    def _str2bool(value):
        """Parse a command-line boolean ('true'/'false', 'yes'/'no', '1'/'0', ...).

        Without this, argparse stores the raw string, so ``--test False``
        would be the truthy string "False".
        """
        if isinstance(value, bool):
            return value
        lowered = str(value).strip().lower()
        if lowered in ("1", "true", "t", "yes", "y"):
            return True
        if lowered in ("0", "false", "f", "no", "n"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    arg_parser = argparse.ArgumentParser(
        description="Predict with and evaluate a trained LucidAtlas model (covariate interpretation)")
    arg_parser.add_argument(
        "--experiment",
        "-e",
        dest="experiment_directory",
        default='/playpen-raid/Author/LucidAtlas/configs/airways/cv5_all/lucidatlas_full.json',
        help="Path to the experiment specification (JSON) file passed to load_json.",
    )
    # NOTE(review): --checkpoint is parsed but not used below; the checkpoint
    # comes from specs["SavedBestCheckpointPath"] instead.
    arg_parser.add_argument(
        "--checkpoint",
        "-c",
        dest="checkpoint",
        default="latest",
        help="The checkpoint weights to use. This can be a number indicated an epoch "
        + "or 'latest' for the latest weights (this is the default)",
    )

    arg_parser.add_argument(
        "--train",
        dest="whether_train",
        type=_str2bool,
        default=True,
        help="whether to train from scratch",
    )

    arg_parser.add_argument(
        "--test",
        dest="whether_test",
        type=_str2bool,
        default=True,
        help="whether to test",
    )

    arg_parser.add_argument(
        "--vis",
        dest="whether_vis",
        type=_str2bool,
        default=True,
        help="whether to vis",
    )
    arg_parser.add_argument(
        "--fold",
        "-f",
        dest="which_fold",
        type=int,
        default=2,
        help="cross-validation fold index",
    )

    args = arg_parser.parse_args()

    if args.whether_test:
        # Same order as the original explicit calls: leave-one-out ("n-1")
        # first, then single-covariate ("1"), each for 'global' then 'indp'.
        for num_of_covariates in ("n-1", "1"):
            for type_of_interpret in ("global", "indp"):
                pred_and_eval_feat_interpret(specs_filename=args.experiment_directory,
                                             which_set='test',
                                             cv_idx=args.which_fold,
                                             type_of_interpret=type_of_interpret,
                                             num_of_covariates=num_of_covariates,
                                             )

    print('1')


